Commit 995e50d6 by xushaohua

feat:部署

parent 8d724f10
......@@ -66,15 +66,15 @@ public class SubmitController {
task.setAction(TaskAction.IMAGINE);
task.setPrompt(prompt);
String promptEn;
int paramStart = prompt.indexOf(" --");
if (paramStart > 0) {
promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart);
} else {
promptEn = this.translateService.translateToEnglish(prompt).trim();
}
if (CharSequenceUtil.isBlank(promptEn)) {
// int paramStart = prompt.indexOf(" --");
// if (paramStart > 0) {
// promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart);
// } else {
// promptEn = this.translateService.translateToEnglish(prompt).trim();
// }
// if (CharSequenceUtil.isBlank(promptEn)) {
promptEn = prompt;
}
// }
if (BannedPromptUtils.isBanned(promptEn)) {
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词");
}
......
......@@ -9,6 +9,7 @@ import com.github.novicezk.midjourney.result.SubmitResultVO;
import com.github.novicezk.midjourney.service.NotifyService;
import com.github.novicezk.midjourney.service.TaskStoreService;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
......@@ -118,6 +119,7 @@ public class TaskQueueHelper {
// 更新任务状态为已提交
changeStatusAndNotify(task, TaskStatus.SUBMITTED);
do {
// 任务提交后让当前任务线程处于休眠状态,当mj有回调监听消息时,会在ImagineMessageHandler中唤醒线程,此时这里就会继续执行更新任务状态
task.sleep();
changeStatusAndNotify(task, task.getStatus());
} while (task.getStatus() == TaskStatus.IN_PROGRESS);
......@@ -140,6 +142,7 @@ public class TaskQueueHelper {
* @param status
*/
public void changeStatusAndNotify(Task task, TaskStatus status) {
log.info("更新任务状态:{}", status);
// 更新任务状态
task.setStatus(status);
// 更新map值
......
......@@ -11,6 +11,7 @@ import com.github.novicezk.midjourney.util.ContentParseData;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.entities.Message;
import net.dv8tion.jda.api.utils.data.DataObject;
import org.json.JSONObject;
import org.springframework.stereotype.Component;
import java.util.Set;
......@@ -31,12 +32,14 @@ public class ImagineMessageHandler extends MessageHandler {
@Override
public void handle(MessageType messageType, DataObject message) {
log.info("监听处理绘图数据:{}", messageType);
String content = getMessageContent(message);
ContentParseData parseData = parse(content);
if (parseData == null) {
return;
}
String realPrompt = this.discordHelper.getRealPrompt(parseData.getPrompt());
log.info("realPrompt=======>{}", realPrompt);
if (MessageType.CREATE == messageType) {
if ("Waiting to start".equals(parseData.getStatus())) {
// 开始
......@@ -73,6 +76,7 @@ public class ImagineMessageHandler extends MessageHandler {
.setStatusSet(Set.of(TaskStatus.SUBMITTED, TaskStatus.IN_PROGRESS));
Task task = this.taskQueueHelper.findRunningTask(taskPredicate(condition, realPrompt))
.findFirst().orElse(null);
log.info("进度更新:{}", JSONObject.valueToString(task));
if (task == null) {
return;
}
......@@ -142,7 +146,7 @@ public class ImagineMessageHandler extends MessageHandler {
}
private Predicate<Task> taskPredicate(TaskCondition condition, String prompt) {
return condition.and(t -> prompt.startsWith(t.getPromptEn()));
return condition.and(t -> prompt.replaceAll("\\s", "").equals(t.getPromptEn().replaceAll("\\s", "")));
}
private ContentParseData parse(String content) {
......
package com.github.novicezk.midjourney.wss.user;
import cn.hutool.json.JSON;
import com.github.novicezk.midjourney.ProxyProperties;
import com.github.novicezk.midjourney.enums.MessageType;
import com.github.novicezk.midjourney.wss.handle.MessageHandler;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
import org.json.JSONObject;
import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ApplicationListener;
......
......@@ -102,6 +102,7 @@ public class UserWebSocketStarter extends WebSocketAdapter implements WebSocketS
String json = new String(decompressBinary, StandardCharsets.UTF_8);
DataObject data = DataObject.fromJson(json);
int opCode = data.getInt("op");
log.info("监听二进制消息类型opCode:{}", opCode);
if (opCode != WebSocketCode.HEARTBEAT_ACK) {
this.sequence.incrementAndGet();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment