Commit 995e50d6 by xushaohua

feat:部署

parent 8d724f10
...@@ -66,15 +66,15 @@ public class SubmitController { ...@@ -66,15 +66,15 @@ public class SubmitController {
task.setAction(TaskAction.IMAGINE); task.setAction(TaskAction.IMAGINE);
task.setPrompt(prompt); task.setPrompt(prompt);
String promptEn; String promptEn;
int paramStart = prompt.indexOf(" --"); // int paramStart = prompt.indexOf(" --");
if (paramStart > 0) { // if (paramStart > 0) {
promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart); // promptEn = this.translateService.translateToEnglish(prompt.substring(0, paramStart)).trim() + prompt.substring(paramStart);
} else { // } else {
promptEn = this.translateService.translateToEnglish(prompt).trim(); // promptEn = this.translateService.translateToEnglish(prompt).trim();
} // }
if (CharSequenceUtil.isBlank(promptEn)) { // if (CharSequenceUtil.isBlank(promptEn)) {
promptEn = prompt; promptEn = prompt;
} // }
if (BannedPromptUtils.isBanned(promptEn)) { if (BannedPromptUtils.isBanned(promptEn)) {
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词"); return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词");
} }
......
...@@ -9,6 +9,7 @@ import com.github.novicezk.midjourney.result.SubmitResultVO; ...@@ -9,6 +9,7 @@ import com.github.novicezk.midjourney.result.SubmitResultVO;
import com.github.novicezk.midjourney.service.NotifyService; import com.github.novicezk.midjourney.service.NotifyService;
import com.github.novicezk.midjourney.service.TaskStoreService; import com.github.novicezk.midjourney.service.TaskStoreService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
...@@ -60,7 +61,7 @@ public class TaskQueueHelper { ...@@ -60,7 +61,7 @@ public class TaskQueueHelper {
} }
public Stream<Task> findRunningTask(Predicate<Task> condition) { public Stream<Task> findRunningTask(Predicate<Task> condition) {
return this.runningTasks.stream().filter(condition); return this.runningTasks.stream().filter(condition);
} }
public Future<?> getRunningFuture(String taskId) { public Future<?> getRunningFuture(String taskId) {
...@@ -118,6 +119,7 @@ public class TaskQueueHelper { ...@@ -118,6 +119,7 @@ public class TaskQueueHelper {
// 更新任务状态为已提交 // 更新任务状态为已提交
changeStatusAndNotify(task, TaskStatus.SUBMITTED); changeStatusAndNotify(task, TaskStatus.SUBMITTED);
do { do {
// 任务提交后让当前任务线程处于休眠状态,当mj有回调监听消息时,会在ImagineMessageHandler中唤醒线程,此时这里就会继续执行更新任务状态
task.sleep(); task.sleep();
changeStatusAndNotify(task, task.getStatus()); changeStatusAndNotify(task, task.getStatus());
} while (task.getStatus() == TaskStatus.IN_PROGRESS); } while (task.getStatus() == TaskStatus.IN_PROGRESS);
...@@ -140,6 +142,7 @@ public class TaskQueueHelper { ...@@ -140,6 +142,7 @@ public class TaskQueueHelper {
* @param status * @param status
*/ */
public void changeStatusAndNotify(Task task, TaskStatus status) { public void changeStatusAndNotify(Task task, TaskStatus status) {
log.info("更新任务状态:{}", status);
// 更新任务状态 // 更新任务状态
task.setStatus(status); task.setStatus(status);
// 更新map值 // 更新map值
......
...@@ -11,6 +11,7 @@ import com.github.novicezk.midjourney.util.ContentParseData; ...@@ -11,6 +11,7 @@ import com.github.novicezk.midjourney.util.ContentParseData;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.entities.Message; import net.dv8tion.jda.api.entities.Message;
import net.dv8tion.jda.api.utils.data.DataObject; import net.dv8tion.jda.api.utils.data.DataObject;
import org.json.JSONObject;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.Set; import java.util.Set;
...@@ -31,12 +32,14 @@ public class ImagineMessageHandler extends MessageHandler { ...@@ -31,12 +32,14 @@ public class ImagineMessageHandler extends MessageHandler {
@Override @Override
public void handle(MessageType messageType, DataObject message) { public void handle(MessageType messageType, DataObject message) {
log.info("监听处理绘图数据:{}", messageType);
String content = getMessageContent(message); String content = getMessageContent(message);
ContentParseData parseData = parse(content); ContentParseData parseData = parse(content);
if (parseData == null) { if (parseData == null) {
return; return;
} }
String realPrompt = this.discordHelper.getRealPrompt(parseData.getPrompt()); String realPrompt = this.discordHelper.getRealPrompt(parseData.getPrompt());
log.info("realPrompt=======>{}", realPrompt);
if (MessageType.CREATE == messageType) { if (MessageType.CREATE == messageType) {
if ("Waiting to start".equals(parseData.getStatus())) { if ("Waiting to start".equals(parseData.getStatus())) {
// 开始 // 开始
...@@ -73,6 +76,7 @@ public class ImagineMessageHandler extends MessageHandler { ...@@ -73,6 +76,7 @@ public class ImagineMessageHandler extends MessageHandler {
.setStatusSet(Set.of(TaskStatus.SUBMITTED, TaskStatus.IN_PROGRESS)); .setStatusSet(Set.of(TaskStatus.SUBMITTED, TaskStatus.IN_PROGRESS));
Task task = this.taskQueueHelper.findRunningTask(taskPredicate(condition, realPrompt)) Task task = this.taskQueueHelper.findRunningTask(taskPredicate(condition, realPrompt))
.findFirst().orElse(null); .findFirst().orElse(null);
log.info("进度更新:{}", JSONObject.valueToString(task));
if (task == null) { if (task == null) {
return; return;
} }
...@@ -142,7 +146,7 @@ public class ImagineMessageHandler extends MessageHandler { ...@@ -142,7 +146,7 @@ public class ImagineMessageHandler extends MessageHandler {
} }
private Predicate<Task> taskPredicate(TaskCondition condition, String prompt) { private Predicate<Task> taskPredicate(TaskCondition condition, String prompt) {
return condition.and(t -> prompt.startsWith(t.getPromptEn())); return condition.and(t -> prompt.replaceAll("\\s", "").equals(t.getPromptEn().replaceAll("\\s", "")));
} }
private ContentParseData parse(String content) { private ContentParseData parse(String content) {
......
package com.github.novicezk.midjourney.wss.user; package com.github.novicezk.midjourney.wss.user;
import cn.hutool.json.JSON;
import com.github.novicezk.midjourney.ProxyProperties; import com.github.novicezk.midjourney.ProxyProperties;
import com.github.novicezk.midjourney.enums.MessageType; import com.github.novicezk.midjourney.enums.MessageType;
import com.github.novicezk.midjourney.wss.handle.MessageHandler; import com.github.novicezk.midjourney.wss.handle.MessageHandler;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject; import net.dv8tion.jda.api.utils.data.DataObject;
import org.json.JSONObject;
import org.springframework.boot.context.event.ApplicationStartedEvent; import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
......
...@@ -102,6 +102,7 @@ public class UserWebSocketStarter extends WebSocketAdapter implements WebSocketS ...@@ -102,6 +102,7 @@ public class UserWebSocketStarter extends WebSocketAdapter implements WebSocketS
String json = new String(decompressBinary, StandardCharsets.UTF_8); String json = new String(decompressBinary, StandardCharsets.UTF_8);
DataObject data = DataObject.fromJson(json); DataObject data = DataObject.fromJson(json);
int opCode = data.getInt("op"); int opCode = data.getInt("op");
log.info("监听二进制消息类型opCode:{}", opCode);
if (opCode != WebSocketCode.HEARTBEAT_ACK) { if (opCode != WebSocketCode.HEARTBEAT_ACK) {
this.sequence.incrementAndGet(); this.sequence.incrementAndGet();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment