diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml b/bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml index 6f1fea818..62ca8b3f7 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml @@ -39,6 +39,10 @@ org.apache.bigtop bigtop-manager-ai-core + + org.apache.bigtop + bigtop-manager-dao + diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java index 279567837..dfae2f71b 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java @@ -21,12 +21,16 @@ import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider; import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory; import org.apache.bigtop.manager.ai.core.enums.PlatformType; +import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; +import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; import org.apache.bigtop.manager.ai.core.factory.ToolBox; import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; import org.apache.bigtop.manager.ai.openai.OpenAIAssistant; +import org.apache.commons.lang3.NotImplementedException; + import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.store.memory.chat.ChatMemoryStore; import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; @@ -35,13 +39,19 @@ public class GeneralAssistantFactory extends AbstractAIAssistantFactory { - private SystemPromptProvider systemPromptProvider = new LocSystemPromptProvider(); - private ChatMemoryStore chatMemoryStore = new InMemoryChatMemoryStore(); + private final SystemPromptProvider systemPromptProvider; + private final ChatMemoryStore chatMemoryStore; - public GeneralAssistantFactory() {} + public GeneralAssistantFactory() { + this(new LocSystemPromptProvider(), new InMemoryChatMemoryStore()); + } public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider) { - this.systemPromptProvider = systemPromptProvider; + this(systemPromptProvider, new InMemoryChatMemoryStore()); + } + + public GeneralAssistantFactory(ChatMemoryStore chatMemoryStore) { + this(new LocSystemPromptProvider(), chatMemoryStore); } public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMemoryStore chatMemoryStore) { @@ -51,29 +61,33 @@ public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMe @Override public AIAssistant createWithPrompt( - PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, Object promptId) { - AIAssistant aiAssistant = create(platformType, assistantConfig, id); - SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(promptId); - aiAssistant.setSystemPrompt(systemPrompt); - return aiAssistant; - } - - @Override - public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) { + PlatformType platformType, + AIAssistantConfigProvider assistantConfig, + Object id, + SystemPrompt systemPrompts) { + AIAssistant aiAssistant; if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) { - AIAssistant aiAssistant = OpenAIAssistant.builder() + aiAssistant = OpenAIAssistant.builder() .id(id) .memoryStore(chatMemoryStore) .withConfigProvider(assistantConfig) .build(); - aiAssistant.setSystemPrompt(systemPromptProvider.getSystemPrompt()); - return aiAssistant; + } else { + throw new PlatformNotFoundException(platformType.getValue()); } - return null; + + SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(systemPrompts); + aiAssistant.setSystemPrompt(systemPrompt); + return aiAssistant; + } + + @Override + public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) { + return createWithPrompt(platformType, assistantConfig, id, SystemPrompt.DEFAULT_PROMPT); } @Override public ToolBox createToolBox(PlatformType platformType) { - return null; + throw new NotImplementedException("ToolBox is not implemented for GeneralAssistantFactory"); } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java index f632f8392..5266eadb1 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/AIAssistantConfig.java @@ -24,42 +24,83 @@ import java.util.Map; public class AIAssistantConfig implements AIAssistantConfigProvider { - private final Map configMap; - private AIAssistantConfig(Map configMap) { - this.configMap = configMap; + /** + * Model name for platform that we want to use + */ + private final String model; + + /** + * Credentials for different platforms + */ + private final Map credentials; + + /** + * Platform extra configs are put here + */ + private final Map configs; + + private AIAssistantConfig(String model, Map credentials, Map configMap) { + this.model = model; + this.credentials = credentials; + this.configs = configMap; } public static Builder builder() { return new Builder(); } - public static Builder withDefault(String baseUrl, String apiKey) { - Builder builder = new Builder(); - return builder.set("baseUrl", baseUrl).set("apiKey", apiKey); + @Override + public String getModel() { + return model; } @Override - public Map configs() { + public Map getCredentials() { + return credentials; + } - return configMap; + @Override + public Map getConfigs() { + return configs; } public static class Builder { - private final Map configs; + private String model; + + private final Map credentials = new HashMap<>(); + + private final Map configs = new HashMap<>(); - public Builder() { - configs = new HashMap<>(); - configs.put("memoryLen", "30"); + public Builder() {} + + public Builder setModel(String model) { + this.model = model; + return this; } - public Builder set(String key, String value) { + public Builder addCredential(String key, String value) { + credentials.put(key, value); + return this; + } + + public Builder addCredentials(Map credentialMap) { + credentials.putAll(credentialMap); + return this; + } + + public Builder addConfig(String key, String value) { configs.put(key, value); return this; } + public Builder addConfigs(Map configMap) { + configs.putAll(configMap); + return this; + } + public AIAssistantConfig build() { - return new AIAssistantConfig(configs); + return new AIAssistantConfig(model, credentials, configs); } } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java index 1603601a1..b61dd16cd 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/provider/LocSystemPromptProvider.java @@ -18,6 +18,7 @@ */ package org.apache.bigtop.manager.ai.assistant.provider; +import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; import org.springframework.util.ResourceUtils; @@ -29,41 +30,37 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; -import java.util.Objects; @Slf4j public class LocSystemPromptProvider implements SystemPromptProvider { - public static final String DEFAULT = "default"; private static final String SYSTEM_PROMPT_PATH = "src/main/resources/"; - private static final String DEFAULT_NAME = "big-data-professor.st"; @Override - public SystemMessage getSystemPrompt(Object id) { - if (Objects.equals(id.toString(), DEFAULT)) { - return getSystemPrompt(); - } else { - return loadPromptFromFile(id.toString()); + public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) { + if (systemPrompt == SystemPrompt.DEFAULT_PROMPT) { + systemPrompt = SystemPrompt.BIGDATA_PROFESSOR; } + + return loadPromptFromFile(systemPrompt.getValue()); } @Override public SystemMessage getSystemPrompt() { - return loadPromptFromFile(DEFAULT_NAME); + return getSystemPrompt(SystemPrompt.DEFAULT_PROMPT); } private SystemMessage loadPromptFromFile(String fileName) { - final String filePath = SYSTEM_PROMPT_PATH + fileName; + final String filePath = SYSTEM_PROMPT_PATH + fileName + ".st"; try { File file = ResourceUtils.getFile(filePath); String text = Files.readString(file.toPath(), StandardCharsets.UTF_8); return SystemMessage.from(text); } catch (IOException e) { - // log.error( "Exception occurred while loading SystemPrompt from local. Here is some information:{}", e.getMessage()); - return SystemMessage.from(""); + return SystemMessage.from("You are a helpful assistant."); } } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java new file mode 100644 index 000000000..ba4fd95ef --- /dev/null +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.ai.assistant.store; + +import org.apache.bigtop.manager.ai.core.enums.MessageSender; +import org.apache.bigtop.manager.dao.po.ChatMessagePO; +import org.apache.bigtop.manager.dao.po.ChatThreadPO; +import org.apache.bigtop.manager.dao.repository.ChatMessageDao; +import org.apache.bigtop.manager.dao.repository.ChatThreadDao; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class PersistentChatMemoryStore implements ChatMemoryStore { + + private final ChatThreadDao chatThreadDao; + private final ChatMessageDao chatMessageDao; + + public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) { + this.chatThreadDao = chatThreadDao; + this.chatMessageDao = chatMessageDao; + } + + private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) { + String sender = chatMessagePO.getSender().toLowerCase(); + if (sender.equals(MessageSender.AI.getValue())) { + return new AiMessage(chatMessagePO.getMessage()); + } else if (sender.equals(MessageSender.USER.getValue())) { + return new UserMessage(chatMessagePO.getMessage()); + } else if (sender.equals(MessageSender.SYSTEM.getValue())) { + return new SystemMessage(chatMessagePO.getMessage()); + } else { + return null; + } + } + + private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) { + ChatMessagePO chatMessagePO = new ChatMessagePO(); + if (chatMessage.type().equals(ChatMessageType.AI)) { + chatMessagePO.setSender(MessageSender.AI.getValue()); + AiMessage aiMessage = (AiMessage) chatMessage; + chatMessagePO.setMessage(aiMessage.text()); + } else if (chatMessage.type().equals(ChatMessageType.USER)) { + chatMessagePO.setSender(MessageSender.USER.getValue()); + UserMessage userMessage = (UserMessage) chatMessage; + chatMessagePO.setMessage(userMessage.singleText()); + } else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) { + chatMessagePO.setSender(MessageSender.SYSTEM.getValue()); + SystemMessage systemMessage = (SystemMessage) chatMessage; + chatMessagePO.setMessage(systemMessage.text()); + } else { + chatMessagePO.setSender(chatMessage.type().toString()); + } + ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadId); + chatMessagePO.setUserId(chatThreadPO.getUserId()); + chatMessagePO.setThreadId(chatThreadId); + return chatMessagePO; + } + + @Override + public List getMessages(Object threadId) { + List chatMessages = chatMessageDao.findAllByThreadId((Long) threadId); + if (chatMessages.isEmpty()) { + return new ArrayList<>(); + } else { + return chatMessages.stream().map(this::convertToChatMessage).collect(Collectors.toList()); + } + } + + @Override + public void updateMessages(Object threadId, List messages) { + ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId); + chatMessageDao.save(chatMessagePO); + } + + @Override + public void deleteMessages(Object threadId) { + chatMessageDao.deleteByThreadId((Long) threadId); + } +} diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/AIAssistantServiceTest.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/AIAssistantServiceTest.java index 40c21d58d..0e623ef00 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/AIAssistantServiceTest.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/AIAssistantServiceTest.java @@ -32,18 +32,18 @@ import dev.langchain4j.model.openai.OpenAiChatModelName; import reactor.core.publisher.Flux; +import java.util.UUID; + import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.Mockito.when; public class AIAssistantServiceTest { private AIAssistantConfigProvider configProvider = AIAssistantConfig.builder() - .set("apiKey", "sk-") + .addConfig("apiKey", "sk-") // The `baseUrl` has a default value that is automatically generated based on the `PlatformType`. - .set("baseUrl", "https://api.openai.com/v1") - // default 30 - .set("memoryLen", "10") - .set("modelName", OpenAiChatModelName.GPT_3_5_TURBO.toString()) + .addConfig("baseUrl", "https://api.openai.com/v1") + .addConfig("modelName", OpenAiChatModelName.GPT_3_5_TURBO.toString()) .build(); @Mock @@ -52,6 +52,8 @@ public class AIAssistantServiceTest { @Mock private AIAssistantFactory aiAssistantFactory; + private final String threadId = UUID.randomUUID().toString(); + @BeforeEach public void init() { MockitoAnnotations.openMocks(this); @@ -62,13 +64,14 @@ public void init() { emmit.next(text.charAt(i) + ""); } })); - when(aiAssistantFactory.create(PlatformType.OPENAI, configProvider)).thenReturn(this.aiAssistant); - when(aiAssistant.getPlatform()).thenReturn(PlatformType.OPENAI.getValue()); + when(aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId)) + .thenReturn(this.aiAssistant); + when(aiAssistant.getPlatform()).thenReturn(PlatformType.OPENAI); } @Test public void createNew2SimpleChat() { - AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider); + AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId); String ask = aiAssistant.ask("1?"); assertFalse(ask.isEmpty()); System.out.println(ask); @@ -76,7 +79,7 @@ public void createNew2SimpleChat() { @Test public void createNew2StreamChat() throws InterruptedException { - AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider); + AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId); Flux stringFlux = aiAssistant.streamAsk("stream 1?"); stringFlux.subscribe( System.out::println, diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/SystemPromptProviderTests.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/SystemPromptProviderTests.java index db9be1ae0..1f768e17d 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/SystemPromptProviderTests.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/test/java/org/apache/bigtop/manager/ai/assistant/SystemPromptProviderTests.java @@ -19,6 +19,7 @@ package org.apache.bigtop.manager.ai.assistant; import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider; +import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider; import org.junit.jupiter.api.Test; @@ -30,22 +31,15 @@ public class SystemPromptProviderTests { - private SystemPromptProvider systemPromptProvider = new LocSystemPromptProvider(); - - @Test - public void loadSystemPromptTest() { - System.out.println(systemPromptProvider.getSystemPrompt()); - } + private final SystemPromptProvider systemPromptProvider = new LocSystemPromptProvider(); @Test public void loadSystemPromptByIdTest() { - SystemMessage systemPrompt1 = systemPromptProvider.getSystemPrompt("big-data-professor.st"); + SystemMessage systemPrompt1 = systemPromptProvider.getSystemPrompt(SystemPrompt.BIGDATA_PROFESSOR); assertFalse(systemPrompt1.text().isEmpty()); - System.out.println(systemPrompt1.text()); - SystemMessage systemPrompt2 = systemPromptProvider.getSystemPrompt(LocSystemPromptProvider.DEFAULT); + SystemMessage systemPrompt2 = systemPromptProvider.getSystemPrompt(); assertFalse(systemPrompt2.text().isEmpty()); - System.out.println(systemPrompt2.text()); assertEquals(systemPrompt1.text(), systemPrompt2.text()); } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java index e5087e303..5b0f383a8 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/AbstractAIAssistant.java @@ -37,6 +37,8 @@ public abstract class AbstractAIAssistant implements AIAssistant { private final Object assistantId; private final ChatMemory chatMemory; + protected static final Integer MEMORY_LEN = 10; + public AbstractAIAssistant( ChatLanguageModel chatLanguageModel, StreamingChatLanguageModel streamingChatLanguageModel, @@ -50,29 +52,25 @@ public AbstractAIAssistant( @Override public Flux streamAsk(ChatMessage chatMessage) { chatMemory.add(chatMessage); - Flux streamAiMessage = Flux.create( - emitter -> { - streamingChatLanguageModel.generate(chatMemory.messages(), new StreamingResponseHandler<>() { - @Override - public void onNext(String token) { - emitter.next(token); - } + return Flux.create( + emitter -> streamingChatLanguageModel.generate(chatMemory.messages(), new StreamingResponseHandler<>() { + @Override + public void onNext(String token) { + emitter.next(token); + } - @Override - public void onError(Throwable error) { - emitter.error(error); - } + @Override + public void onError(Throwable error) { + emitter.error(error); + } - @Override - public void onComplete(Response response) { - StreamingResponseHandler.super.onComplete(response); - chatMemory.add(response.content()); - } - }); - }, + @Override + public void onComplete(Response response) { + StreamingResponseHandler.super.onComplete(response); + chatMemory.add(response.content()); + } + }), FluxSink.OverflowStrategy.BUFFER); - - return streamAiMessage; } @Override diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/MessageSender.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/MessageSender.java new file mode 100644 index 000000000..1c93085a8 --- /dev/null +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/MessageSender.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.ai.core.enums; + +import lombok.Getter; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +@Getter +public enum MessageSender { + USER("user"), + AI("ai"), + SYSTEM("system"); + + private final String value; + + MessageSender(String value) { + this.value = value; + } + + public static List getSenders() { + return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList()); + } + + public static MessageSender getMessageSender(String value) { + if (Objects.isNull(value) || value.isEmpty()) { + return null; + } + for (MessageSender messageSender : MessageSender.values()) { + if (messageSender.value.equals(value)) { + return messageSender; + } + } + + return null; + } +} diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/PlatformType.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/PlatformType.java index 0f5cb6082..cadbb2a66 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/PlatformType.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/PlatformType.java @@ -18,11 +18,14 @@ */ package org.apache.bigtop.manager.ai.core.enums; +import lombok.Getter; + import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +@Getter public enum PlatformType { OPENAI("openai"); @@ -47,8 +50,4 @@ public static PlatformType getPlatformType(String value) { } return null; } - - public String getValue() { - return this.value; - } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/SystemPrompt.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/SystemPrompt.java new file mode 100644 index 000000000..be0eea340 --- /dev/null +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/enums/SystemPrompt.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.ai.core.enums; + +import lombok.Getter; + +@Getter +public enum SystemPrompt { + DEFAULT_PROMPT("default"), + BIGDATA_PROFESSOR("big-data-professor"), + LANGUAGE_PROMPT("language-prompt"); + ; + + private final String value; + + SystemPrompt(String value) { + this.value = value; + } +} diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java index 8f4ef8e2e..6834b0513 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistant.java @@ -18,6 +18,8 @@ */ package org.apache.bigtop.manager.ai.core.factory; +import org.apache.bigtop.manager.ai.core.enums.PlatformType; + import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; @@ -57,7 +59,7 @@ public interface AIAssistant { * This is used to get the AIAssistant's Platform * @return */ - String getPlatform(); + PlatformType getPlatform(); void setSystemPrompt(SystemMessage systemPrompt); diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java index 6610f9c4b..d6b24034b 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/factory/AIAssistantFactory.java @@ -19,51 +19,28 @@ package org.apache.bigtop.manager.ai.core.factory; import org.apache.bigtop.manager.ai.core.enums.PlatformType; -import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException; +import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; -import java.util.Objects; import java.util.UUID; public interface AIAssistantFactory { AIAssistant createWithPrompt( - PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, Object promptId); + PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, SystemPrompt systemPrompt); AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id); - ToolBox createToolBox(PlatformType platformType); - - default AIAssistant createWithPrompt( - PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object prompt) { - return createWithPrompt(platformType, assistantConfig, UUID.randomUUID().toString(), prompt); - } - - default AIAssistant create(String platform, AIAssistantConfigProvider assistantConfigProvider, Object id) { - PlatformType platformType = PlatformType.getPlatformType(platform); - if (Objects.isNull(platformType)) { - throw new PlatformNotFoundException(platform); - } - return create(platformType, assistantConfigProvider, id); - } - - default AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfigProvider) { - return create(platformType, assistantConfigProvider, UUID.randomUUID().toString()); - } - - default AIAssistant create(String platform, AIAssistantConfigProvider assistantConfig) { - PlatformType platformType = PlatformType.getPlatformType(platform); - if (Objects.isNull(platformType)) { - throw new PlatformNotFoundException(platform); - } - return create(platformType, assistantConfig); + /** + * TODO Create AIAssistant without memory, should delete UUID + * + * @param platformType platform type + * @param assistantConfig assistant config + * @return AIAssistant + */ + default AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig) { + return create(platformType, assistantConfig, UUID.randomUUID().toString()); } - default ToolBox createToolBox(String platform) { - PlatformType platformType = PlatformType.getPlatformType(platform); - if (Objects.isNull(platformType)) { - throw new PlatformNotFoundException(platform); - } - return createToolBox(platformType); - } + ToolBox createToolBox(PlatformType platformType); } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java index ea05b98ed..04f9dbd59 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/AIAssistantConfigProvider.java @@ -21,5 +21,9 @@ import java.util.Map; public interface AIAssistantConfigProvider { - Map configs(); + String getModel(); + + Map getCredentials(); + + Map getConfigs(); } diff --git a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/SystemPromptProvider.java b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/SystemPromptProvider.java index 5d6865fbd..4340fc8d0 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/SystemPromptProvider.java +++ b/bigtop-manager-ai/bigtop-manager-ai-core/src/main/java/org/apache/bigtop/manager/ai/core/provider/SystemPromptProvider.java @@ -18,12 +18,13 @@ */ package org.apache.bigtop.manager.ai.core.provider; +import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; + import dev.langchain4j.data.message.SystemMessage; public interface SystemPromptProvider { - // Return the SystemPrompt for the specified ID. - SystemMessage getSystemPrompt(Object id); + SystemMessage getSystemPrompt(SystemPrompt systemPrompt); // return default system prompt SystemMessage getSystemPrompt(); diff --git a/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java index 1d76b6254..f5325802b 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-openai/src/main/java/org/apache/bigtop/manager/ai/openai/OpenAIAssistant.java @@ -19,11 +19,10 @@ package org.apache.bigtop.manager.ai.openai; import org.apache.bigtop.manager.ai.core.AbstractAIAssistant; +import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.factory.AIAssistant; import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider; -import org.springframework.util.NumberUtils; - import dev.langchain4j.internal.ValidationUtils; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -33,14 +32,9 @@ import dev.langchain4j.model.openai.OpenAiStreamingChatModel; import dev.langchain4j.store.memory.chat.ChatMemoryStore; -import java.util.HashMap; -import java.util.Map; - public class OpenAIAssistant extends AbstractAIAssistant { - private static final String PLATFORM_NAME = "openai"; private static final String BASE_URL = "https://api.openai.com/v1"; - private static final String MODEL_NAME = "gpt-3.5-turbo"; private OpenAIAssistant( ChatLanguageModel chatLanguageModel, @@ -50,8 +44,8 @@ private OpenAIAssistant( } @Override - public String getPlatform() { - return PLATFORM_NAME; + public PlatformType getPlatform() { + return PlatformType.OPENAI; } public static Builder builder() { @@ -61,16 +55,13 @@ public static Builder builder() { public static class Builder { private Object id; - private Map configs = new HashMap<>(); private ChatMemoryStore chatMemoryStore; + private AIAssistantConfigProvider configProvider; - public Builder() { - configs.put("baseUrl", BASE_URL); - configs.put("modelName", MODEL_NAME); - } + public Builder() {} public Builder withConfigProvider(AIAssistantConfigProvider configProvider) { - this.configs = configProvider.configs(); + this.configProvider = configProvider; return this; } @@ -86,25 +77,23 @@ public Builder memoryStore(ChatMemoryStore chatMemoryStore) { public AIAssistant build() { ValidationUtils.ensureNotNull(id, "id"); - String baseUrl = configs.get("baseUrl"); - String modelName = configs.get("modelName"); - String apiKey = ValidationUtils.ensureNotNull(configs.get("apiKey"), "apiKey"); - Integer memoryLen = ValidationUtils.ensureNotNull( - NumberUtils.parseNumber(configs.get("memoryLen"), Integer.class), "memoryLen not a number."); + String model = ValidationUtils.ensureNotNull(configProvider.getModel(), "model"); + String apiKey = ValidationUtils.ensureNotNull( + configProvider.getCredentials().get("apiKey"), "apiKey"); ChatLanguageModel openAiChatModel = OpenAiChatModel.builder() .apiKey(apiKey) - .baseUrl(baseUrl) - .modelName(modelName) + .baseUrl(BASE_URL) + .modelName(model) .build(); StreamingChatLanguageModel openaiStreamChatModel = OpenAiStreamingChatModel.builder() .apiKey(apiKey) - .baseUrl(baseUrl) - .modelName(modelName) + .baseUrl(BASE_URL) + .modelName(model) .build(); MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder() .id(id) .chatMemoryStore(chatMemoryStore) - .maxMessages(memoryLen) + .maxMessages(MEMORY_LEN) .build(); return new OpenAIAssistant(openAiChatModel, openaiStreamChatModel, chatMemory); } diff --git a/bigtop-manager-dao/pom.xml b/bigtop-manager-dao/pom.xml index 6fed3ad75..ff56c9712 100644 --- a/bigtop-manager-dao/pom.xml +++ b/bigtop-manager-dao/pom.xml @@ -74,6 +74,10 @@ org.apache.tomcat.embed tomcat-embed-core + + org.apache.bigtop + bigtop-manager-ai-core + diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/handler/JsonTypeHandler.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/handler/JsonTypeHandler.java new file mode 100644 index 000000000..4879753b5 --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/handler/JsonTypeHandler.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.handler; + +import org.apache.ibatis.type.BaseTypeHandler; +import org.apache.ibatis.type.JdbcType; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Map; + +public class JsonTypeHandler extends BaseTypeHandler> { + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public void setNonNullParameter(PreparedStatement ps, int i, Map parameter, JdbcType jdbcType) + throws SQLException { + ps.setString(i, convertMapToJson(parameter)); + } + + @Override + public Map getNullableResult(ResultSet rs, String columnName) throws SQLException { + return convertJsonToMap(rs.getString(columnName)); + } + + @Override + public Map getNullableResult(ResultSet rs, int columnIndex) throws SQLException { + return convertJsonToMap(rs.getString(columnIndex)); + } + + @Override + public Map getNullableResult(java.sql.CallableStatement cs, int columnIndex) throws SQLException { + return convertJsonToMap(cs.getString(columnIndex)); + } + + private String convertMapToJson(Map map) { + try { + return objectMapper.writeValueAsString(map); + } catch (Exception e) { + throw new RuntimeException("Error converting map to JSON string", e); + } + } + + private Map convertJsonToMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>() {}); + } catch (Exception e) { + throw new RuntimeException("Error converting JSON string to map", e); + } + } +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatMessagePO.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatMessagePO.java new file mode 100644 index 000000000..423681077 --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatMessagePO.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.po; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +import jakarta.persistence.Column; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.io.Serializable; + +@Data +@EqualsAndHashCode(callSuper = true) +@Table(name = "llm_chat_message") +public class ChatMessagePO extends BasePO implements Serializable { + @Id + @Column(name = "id") + private Long id; + + @Column(name = "message", nullable = false, length = 255) + private String message; + + @Column(name = "sender") + private String sender; + + @Column(name = "user_id") + private Long userId; + + @Column(name = "thread_id") + private Long threadId; +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatThreadPO.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatThreadPO.java new file mode 100644 index 000000000..e4c8628e9 --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/ChatThreadPO.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.po; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +import jakarta.persistence.Column; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.io.Serializable; + +@Data +@EqualsAndHashCode(callSuper = true) +@Table(name = "llm_chat_thread") +public class ChatThreadPO extends BasePO implements Serializable { + @Id + @Column(name = "id") + private Long id; + + @Column(name = "model", nullable = false, length = 255) + private String model; + + @Column(name = "user_id") + private Long userId; + + @Column(name = "platform_id") + private Long platformId; +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformAuthorizedPO.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformAuthorizedPO.java new file mode 100644 index 000000000..f27fcd437 --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformAuthorizedPO.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.po; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +import jakarta.persistence.Column; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.io.Serializable; +import java.util.Map; + +@Data +@EqualsAndHashCode(callSuper = true) +@Table(name = "llm_platform_authorized") +public class PlatformAuthorizedPO extends BasePO implements Serializable { + + @Id + @Column(name = "id") + private Long id; + + @Column(name = "credentials", columnDefinition = "json", nullable = false) + private Map credentials; + + @Column(name = "platform_id") + private Long platformId; +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformPO.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformPO.java new file mode 100644 index 000000000..0180c516b --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/po/PlatformPO.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.po; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +import jakarta.persistence.Column; +import jakarta.persistence.Id; +import jakarta.persistence.Table; +import java.io.Serializable; +import java.util.Map; + +@Data +@EqualsAndHashCode(callSuper = true) +@Table(name = "llm_platform") +public class PlatformPO extends BasePO implements Serializable { + + @Id + @Column(name = "id") + private Long id; + + @Column(name = "name", nullable = false, length = 255) + private String name; + + @Column(name = "credential", columnDefinition = "json", nullable = false) + private Map credential; + + @Column(name = "support_models", length = 255) + private String supportModels; +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatMessageDao.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatMessageDao.java new file mode 100644 index 000000000..59721afdf --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatMessageDao.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.repository; + +import org.apache.bigtop.manager.dao.po.ChatMessagePO; + +import org.apache.ibatis.annotations.Param; + +import java.util.List; + +public interface ChatMessageDao extends BaseDao { + List findAllByThreadId(@Param("threadId") Long threadId); + + void deleteByThreadId(@Param("threadId") Long threadId); +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatThreadDao.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatThreadDao.java new file mode 100644 index 000000000..a691b7b3c --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/ChatThreadDao.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.repository; + +import org.apache.bigtop.manager.dao.po.ChatThreadPO; + +import org.apache.ibatis.annotations.Param; + +import java.util.List; + +public interface ChatThreadDao extends BaseDao { + List findAllByUserId(@Param("userId") Long userId); + + ChatThreadPO findById(Long id); + + List findAllByPlatformAuthorizedIdAndUserId( + @Param("platformId") Long platformAuthorizedId, @Param("userId") Long userId); +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformAuthorizedDao.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformAuthorizedDao.java new file mode 100644 index 000000000..fd801d76c --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformAuthorizedDao.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.repository; + +import org.apache.bigtop.manager.dao.po.PlatformAuthorizedPO; + +import org.apache.ibatis.annotations.Param; + +public interface PlatformAuthorizedDao extends BaseDao { + PlatformAuthorizedPO findByPlatformId(@Param("id") Long platformId); + + void saveWithCredentials(PlatformAuthorizedPO platformAuthorizedPO); +} diff --git a/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformDao.java b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformDao.java new file mode 100644 index 000000000..e2d3c9785 --- /dev/null +++ b/bigtop-manager-dao/src/main/java/org/apache/bigtop/manager/dao/repository/PlatformDao.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.dao.repository; + +import org.apache.bigtop.manager.dao.po.PlatformPO; + +public interface PlatformDao extends BaseDao { + PlatformPO findByPlatformId(Long id); +} diff --git a/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatMessageMapper.xml b/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatMessageMapper.xml new file mode 100644 index 000000000..108410b1c --- /dev/null +++ b/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatMessageMapper.xml @@ -0,0 +1,36 @@ + + + + + + + + + DELETE FROM llm_chat_message + WHERE thread_id = #{threadId} + + + \ No newline at end of file diff --git a/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatThreadMapper.xml b/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatThreadMapper.xml new file mode 100644 index 000000000..be702079b --- /dev/null +++ b/bigtop-manager-dao/src/main/resources/mapper/mysql/ChatThreadMapper.xml @@ -0,0 +1,39 @@ + + + + + + id, user_id, platform_id, model + + + + + + \ No newline at end of file diff --git a/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformAuthorizedMapper.xml b/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformAuthorizedMapper.xml new file mode 100644 index 000000000..f065a4c7b --- /dev/null +++ b/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformAuthorizedMapper.xml @@ -0,0 +1,45 @@ + + + + + + + id, credentials, platfotrm_id + + + + + + + + + + INSERT INTO llm_platform_authorized (platform_id, credentials) + VALUES (#{platformId}, #{credentials, typeHandler=org.apache.bigtop.manager.dao.handler.JsonTypeHandler}) + ON DUPLICATE KEY UPDATE + platform_id = VALUES(platform_id), + credentials = VALUES(credentials) + + + \ No newline at end of file diff --git a/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformMapper.xml b/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformMapper.xml new file mode 100644 index 000000000..e4ca9039b --- /dev/null +++ b/bigtop-manager-dao/src/main/resources/mapper/mysql/PlatformMapper.xml @@ -0,0 +1,38 @@ + + + + + + + id, name, credential, support_models + + + + + + + + + + \ No newline at end of file diff --git a/bigtop-manager-server/pom.xml b/bigtop-manager-server/pom.xml index 408b7f94a..d6aecdefb 100644 --- a/bigtop-manager-server/pom.xml +++ b/bigtop-manager-server/pom.xml @@ -147,6 +147,14 @@ net.devh grpc-client-spring-boot-starter + + org.apache.bigtop + bigtop-manager-ai-core + + + org.apache.bigtop + bigtop-manager-ai-assistant + diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/AIChatController.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/AIChatController.java index 2e3b56100..8446ab97e 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/AIChatController.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/AIChatController.java @@ -18,7 +18,6 @@ */ package org.apache.bigtop.manager.server.controller; -import org.apache.bigtop.manager.server.enums.ResponseStatus; import org.apache.bigtop.manager.server.model.converter.PlatformConverter; import org.apache.bigtop.manager.server.model.dto.PlatformDTO; import org.apache.bigtop.manager.server.model.req.PlatformReq; @@ -75,7 +74,7 @@ public ResponseEntity> platformsAuthCredential(@P @Operation(summary = "platforms", description = "Add authorized platforms") @PutMapping("/platforms") - public ResponseEntity addAuthorizedPlatform(@RequestBody PlatformReq platformReq) { + public ResponseEntity addAuthorizedPlatform(@RequestBody PlatformReq platformReq) { PlatformDTO platformDTO = PlatformConverter.INSTANCE.fromReq2DTO(platformReq); return ResponseEntity.success(chatService.addAuthorizedPlatform(platformDTO)); } @@ -83,11 +82,7 @@ public ResponseEntity addAuthorizedPlatform(@RequestBody PlatformReq @Operation(summary = "platforms", description = "Delete authorized platforms") @DeleteMapping("/platforms/{platformId}") public ResponseEntity deleteAuthorizedPlatform(@PathVariable Long platformId) { - int code = chatService.deleteAuthorizedPlatform(platformId); - if (code != 0) { - return ResponseEntity.error(ResponseStatus.PARAMETER_ERROR, "Permission denied"); - } - return ResponseEntity.success(true); + return ResponseEntity.success(chatService.deleteAuthorizedPlatform(platformId)); } @Operation(summary = "new threads", description = "Create a chat threads") @@ -99,11 +94,7 @@ public ResponseEntity createChatThreads(@PathVariable Long platfor @Operation(summary = "delete threads", description = "Delete a chat threads") @DeleteMapping("platforms/{platformId}/threads/{threadId}") public ResponseEntity deleteChatThreads(@PathVariable Long platformId, @PathVariable Long threadId) { - int code = chatService.deleteChatThreads(platformId, threadId); - if (code != 0) { - return ResponseEntity.error(ResponseStatus.PARAMETER_ERROR, "No Content"); - } - return ResponseEntity.success(true); + return ResponseEntity.success(chatService.deleteChatThreads(platformId, threadId)); } @Operation(summary = "get", description = "Get all threads of a platform") diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ApiExceptionEnum.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ApiExceptionEnum.java index 45ea4235e..01ed94425 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ApiExceptionEnum.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ApiExceptionEnum.java @@ -60,6 +60,14 @@ public enum ApiExceptionEnum { // Command Exceptions -- 18000 ~ 18999 COMMAND_NOT_FOUND(18000, LocaleKeys.COMMAND_NOT_FOUND), COMMAND_NOT_SUPPORTED(18000, LocaleKeys.COMMAND_NOT_SUPPORTED), + + // AI Chat Exceptions -- 19000 ~ 19999 + PLATFORM_NOT_FOUND(19000, LocaleKeys.PLATFORM_NOT_FOUND), + PLATFORM_NOT_AUTHORIZED(19001, LocaleKeys.PLATFORM_NOT_AUTHORIZED), + PERMISSION_DENIED(19002, LocaleKeys.PERMISSION_DENIED), + CREDIT_INCORRECT(19003, LocaleKeys.CREDIT_INCORRECT), + MODEL_NOT_SUPPORTED(19004, LocaleKeys.MODEL_NOT_SUPPORTED), + CHAT_THREAD_NOT_FOUND(19005, LocaleKeys.CHAT_THREAD_NOT_FOUND), ; private final Integer code; diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/LocaleKeys.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/LocaleKeys.java index 22dc64ae6..bbb6af104 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/LocaleKeys.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/LocaleKeys.java @@ -56,6 +56,15 @@ public enum LocaleKeys { COMMAND_NOT_FOUND("command.not.found"), COMMAND_NOT_SUPPORTED("command.not.supported"), + + PLATFORM_NOT_FOUND("platform.not.found"), + PLATFORM_NOT_AUTHORIZED("platform.not.authorized"), + PERMISSION_DENIED("permission.denied"), + CREDIT_INCORRECT("credit.incorrect"), + MODEL_NOT_SUPPORTED("model.not.supported"), + CHAT_THREAD_NOT_FOUND("chat.thread.not.found"), + + CHAT_LANGUAGE_PROMPT("chat.language.prompt"), ; private final String key; diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatMessageConverter.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatMessageConverter.java new file mode 100644 index 000000000..b2c17238d --- /dev/null +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatMessageConverter.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.server.model.converter; + +import org.apache.bigtop.manager.ai.core.enums.MessageSender; +import org.apache.bigtop.manager.dao.po.ChatMessagePO; +import org.apache.bigtop.manager.server.config.MapStructSharedConfig; +import org.apache.bigtop.manager.server.model.vo.ChatMessageVO; + +import org.mapstruct.Mapper; +import org.mapstruct.factory.Mappers; + +@Mapper(config = MapStructSharedConfig.class) +public interface ChatMessageConverter { + ChatMessageConverter INSTANCE = Mappers.getMapper(ChatMessageConverter.class); + + ChatMessageVO fromPO2VO(ChatMessagePO chatMessagePO); + + default MessageSender mapStringToMessageSender(String sender) { + if (sender == null) { + return null; + } + try { + return MessageSender.valueOf(sender.toUpperCase()); + } catch (IllegalArgumentException e) { + return null; + } + } +} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatThreadConverter.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatThreadConverter.java new file mode 100644 index 000000000..6931da212 --- /dev/null +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/ChatThreadConverter.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.server.model.converter; + +import org.apache.bigtop.manager.dao.po.ChatThreadPO; +import org.apache.bigtop.manager.server.config.MapStructSharedConfig; +import org.apache.bigtop.manager.server.model.vo.ChatThreadVO; + +import org.mapstruct.Mapper; +import org.mapstruct.Mapping; +import org.mapstruct.factory.Mappers; + +@Mapper(config = MapStructSharedConfig.class) +public interface ChatThreadConverter { + ChatThreadConverter INSTANCE = Mappers.getMapper(ChatThreadConverter.class); + + @Mapping(source = "id", target = "threadId") + ChatThreadVO fromPO2VO(ChatThreadPO platformAuthorizedPO); +} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformAuthorizedConverter.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformAuthorizedConverter.java new file mode 100644 index 000000000..2e3c87a5b --- /dev/null +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformAuthorizedConverter.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.server.model.converter; + +import org.apache.bigtop.manager.dao.po.PlatformAuthorizedPO; +import org.apache.bigtop.manager.dao.po.PlatformPO; +import org.apache.bigtop.manager.server.config.MapStructSharedConfig; +import org.apache.bigtop.manager.server.model.vo.PlatformAuthorizedVO; + +import org.mapstruct.Context; +import org.mapstruct.Mapper; +import org.mapstruct.Mapping; +import org.mapstruct.factory.Mappers; + +@Mapper(config = MapStructSharedConfig.class) +public interface PlatformAuthorizedConverter { + PlatformAuthorizedConverter INSTANCE = Mappers.getMapper(PlatformAuthorizedConverter.class); + + @Mapping(target = "platformId", source = "id") + @Mapping(target = "supportModels", expression = "java(platformPO.getSupportModels())") + @Mapping(target = "platformName", expression = "java(platformPO.getName())") + PlatformAuthorizedVO fromPO2VO(PlatformAuthorizedPO platformAuthorizedPO, @Context PlatformPO platformPO); +} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformConverter.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformConverter.java index aa4c3432e..85b6e15c7 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformConverter.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/converter/PlatformConverter.java @@ -18,16 +18,40 @@ */ package org.apache.bigtop.manager.server.model.converter; +import org.apache.bigtop.manager.dao.po.PlatformPO; import org.apache.bigtop.manager.server.config.MapStructSharedConfig; import org.apache.bigtop.manager.server.model.dto.PlatformDTO; +import org.apache.bigtop.manager.server.model.req.AuthCredentialReq; import org.apache.bigtop.manager.server.model.req.PlatformReq; +import org.apache.bigtop.manager.server.model.vo.PlatformVO; +import org.mapstruct.AfterMapping; import org.mapstruct.Mapper; +import org.mapstruct.MappingTarget; import org.mapstruct.factory.Mappers; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + @Mapper(config = MapStructSharedConfig.class) public interface PlatformConverter { PlatformConverter INSTANCE = Mappers.getMapper(PlatformConverter.class); PlatformDTO fromReq2DTO(PlatformReq platformReq); + + PlatformVO fromPO2VO(PlatformPO platformPO); + + default Map mapAuthCredentials(List authCredentials) { + if (authCredentials == null) { + return null; + } + return authCredentials.stream() + .collect(Collectors.toMap(AuthCredentialReq::getKey, AuthCredentialReq::getValue)); + } + + @AfterMapping + default void afterMapping(@MappingTarget PlatformDTO platformDTO, PlatformReq platformReq) { + platformDTO.setAuthCredentials(mapAuthCredentials(platformReq.getAuthCredentials())); + } } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformAuthorizedDTO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformAuthorizedDTO.java new file mode 100644 index 000000000..da6cb1f2d --- /dev/null +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformAuthorizedDTO.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.bigtop.manager.server.model.dto; + +import lombok.Data; + +import java.util.Map; + +@Data +public class PlatformAuthorizedDTO { + private String platformName; + private String model; + private Map credentials; + + public PlatformAuthorizedDTO(String name, Map credentialSet, String model) { + this.platformName = name; + this.credentials = credentialSet; + this.model = model; + } +} diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformDTO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformDTO.java index f26245f56..368bf2a95 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformDTO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/dto/PlatformDTO.java @@ -20,10 +20,10 @@ import lombok.Data; -import java.util.List; +import java.util.Map; @Data public class PlatformDTO { private Long platformId; - private List authCredentials; + private Map authCredentials; } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatMessageVO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatMessageVO.java index 981529bde..da226da07 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatMessageVO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatMessageVO.java @@ -18,17 +18,19 @@ */ package org.apache.bigtop.manager.server.model.vo; +import org.apache.bigtop.manager.ai.core.enums.MessageSender; + import lombok.Data; @Data public class ChatMessageVO { - private String sender; + private MessageSender sender; private String message; private String createTime; - public ChatMessageVO(String sender, String messageText, String createTime) { + public ChatMessageVO(MessageSender sender, String messageText, String createTime) { this.sender = sender; this.message = messageText; this.createTime = createTime; diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatThreadVO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatThreadVO.java index 2349fdbd2..c5f84a95b 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatThreadVO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/ChatThreadVO.java @@ -39,4 +39,6 @@ public ChatThreadVO(Long threadId, Long platformId, String model, String createT this.createTime = createTime; this.updateTime = createTime; } + + public ChatThreadVO() {} } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthCredentialVO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthCredentialVO.java index edfe0ff30..7107422e1 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthCredentialVO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthCredentialVO.java @@ -28,6 +28,6 @@ public class PlatformAuthCredentialVO { public PlatformAuthCredentialVO(String name, String displayName) { this.name = name; - this.displayName = name; + this.displayName = displayName; } } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthorizedVO.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthorizedVO.java index e957f7aab..3532cf684 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthorizedVO.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/model/vo/PlatformAuthorizedVO.java @@ -33,4 +33,6 @@ public PlatformAuthorizedVO(long platformId, String name, String models) { this.platformName = name; this.supportModels = models; } + + public PlatformAuthorizedVO() {} } diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/AIChatService.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/AIChatService.java index d3cf12556..06042ba62 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/AIChatService.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/AIChatService.java @@ -34,15 +34,15 @@ public interface AIChatService { List authorizedPlatforms(); - PlatformVO addAuthorizedPlatform(PlatformDTO platformDTO); + PlatformAuthorizedVO addAuthorizedPlatform(PlatformDTO platformDTO); List platformsAuthCredential(Long platformId); - int deleteAuthorizedPlatform(Long platformId); + boolean deleteAuthorizedPlatform(Long platformId); ChatThreadVO createChatThreads(Long platformId, String model); - int deleteChatThreads(Long platformId, Long threadId); + boolean deleteChatThreads(Long platformId, Long threadId); List getAllChatThreads(Long platformId, String model); diff --git a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/AIChatServiceImpl.java b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/AIChatServiceImpl.java index 6dca1bcaa..e04787102 100644 --- a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/AIChatServiceImpl.java +++ b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/AIChatServiceImpl.java @@ -18,7 +18,29 @@ */ package org.apache.bigtop.manager.server.service.impl; -import org.apache.bigtop.manager.common.utils.DateUtils; +import org.apache.bigtop.manager.ai.assistant.GeneralAssistantFactory; +import org.apache.bigtop.manager.ai.assistant.provider.AIAssistantConfig; +import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore; +import org.apache.bigtop.manager.ai.core.enums.MessageSender; +import org.apache.bigtop.manager.ai.core.enums.PlatformType; +import org.apache.bigtop.manager.ai.core.factory.AIAssistant; +import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory; +import org.apache.bigtop.manager.dao.po.ChatMessagePO; +import org.apache.bigtop.manager.dao.po.ChatThreadPO; +import org.apache.bigtop.manager.dao.po.PlatformAuthorizedPO; +import org.apache.bigtop.manager.dao.po.PlatformPO; +import org.apache.bigtop.manager.dao.repository.ChatMessageDao; +import org.apache.bigtop.manager.dao.repository.ChatThreadDao; +import org.apache.bigtop.manager.dao.repository.PlatformAuthorizedDao; +import org.apache.bigtop.manager.dao.repository.PlatformDao; +import org.apache.bigtop.manager.server.enums.ApiExceptionEnum; +import org.apache.bigtop.manager.server.exception.ApiException; +import org.apache.bigtop.manager.server.holder.SessionUserHolder; +import org.apache.bigtop.manager.server.model.converter.ChatMessageConverter; +import org.apache.bigtop.manager.server.model.converter.ChatThreadConverter; +import org.apache.bigtop.manager.server.model.converter.PlatformAuthorizedConverter; +import org.apache.bigtop.manager.server.model.converter.PlatformConverter; +import org.apache.bigtop.manager.server.model.dto.PlatformAuthorizedDTO; import org.apache.bigtop.manager.server.model.dto.PlatformDTO; import org.apache.bigtop.manager.server.model.vo.ChatMessageVO; import org.apache.bigtop.manager.server.model.vo.ChatThreadVO; @@ -27,153 +49,275 @@ import org.apache.bigtop.manager.server.model.vo.PlatformVO; import org.apache.bigtop.manager.server.service.AIChatService; +import org.jetbrains.annotations.NotNull; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import lombok.extern.slf4j.Slf4j; +import reactor.core.publisher.Flux; -import java.io.IOException; +import jakarta.annotation.Resource; import java.util.ArrayList; -import java.util.Date; +import java.util.HashMap; import java.util.List; -import java.util.Random; +import java.util.Map; +import java.util.Objects; -@Slf4j @Service public class AIChatServiceImpl implements AIChatService { + @Resource + private PlatformDao platformDao; + + @Resource + private PlatformAuthorizedDao platformAuthorizedDao; + + @Resource + private ChatThreadDao chatThreadDao; + + @Resource + private ChatMessageDao chatMessageDao; + + private AIAssistantFactory aiAssistantFactory; + + private final AIAssistantFactory aiTestFactory = new GeneralAssistantFactory(); + + public AIAssistantFactory getAiAssistantFactory() { + if (aiAssistantFactory == null) { + aiAssistantFactory = + new GeneralAssistantFactory(new PersistentChatMemoryStore(chatThreadDao, chatMessageDao)); + } + return aiAssistantFactory; + } + + private AIAssistantConfig getAIAssistantConfig(PlatformAuthorizedDTO platformAuthorizedDTO) { + return AIAssistantConfig.builder() + .setModel(platformAuthorizedDTO.getModel()) + .addCredentials(platformAuthorizedDTO.getCredentials()) + .build(); + } + + private PlatformType getPlatformType(String platformName) { + return PlatformType.getPlatformType(platformName.toLowerCase()); + } + + private AIAssistant buildAIAssistant(PlatformAuthorizedDTO platformAuthorizedDTO, Long threadId) { + return getAiAssistantFactory() + .create( + getPlatformType(platformAuthorizedDTO.getPlatformName()), + getAIAssistantConfig(platformAuthorizedDTO), + threadId); + } + + private Boolean testAuthorization(PlatformAuthorizedDTO platformAuthorizedDTO) { + AIAssistant aiAssistant = aiTestFactory.create( + getPlatformType(platformAuthorizedDTO.getPlatformName()), getAIAssistantConfig(platformAuthorizedDTO)); + try { + aiAssistant.ask("1+1="); + } catch (Exception e) { + throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage()); + } + + return true; + } + @Override public List platforms() { + List platformPOs = platformDao.findAll(); List platforms = new ArrayList<>(); - platforms.add(new PlatformVO(1L, "OpenAI", "GPT-3.5,GPT-4o")); - platforms.add(new PlatformVO(2L, "ChatGLM", "GPT-3.5,GPT-4o")); + for (PlatformPO platformPO : platformPOs) { + platforms.add(PlatformConverter.INSTANCE.fromPO2VO(platformPO)); + } return platforms; } @Override public List authorizedPlatforms() { List authorizedPlatforms = new ArrayList<>(); - authorizedPlatforms.add(new PlatformAuthorizedVO(1L, "OpenAI", "GPT-3.5,GPT-4o")); - authorizedPlatforms.add(new PlatformAuthorizedVO(2L, "ChatGLM", "GPT-4o")); + List authorizedPlatformPOs = platformAuthorizedDao.findAll(); + for (PlatformAuthorizedPO authorizedPlatformPO : authorizedPlatformPOs) { + PlatformPO platformPO = platformDao.findById(authorizedPlatformPO.getPlatformId()); + authorizedPlatforms.add(PlatformAuthorizedConverter.INSTANCE.fromPO2VO(authorizedPlatformPO, platformPO)); + } + return authorizedPlatforms; } @Override - public PlatformVO addAuthorizedPlatform(PlatformDTO platformDTO) { - log.info("Adding authorized platform: {}", platformDTO); - log.info(platformDTO.getAuthCredentials().toString()); - return new PlatformVO(1L, "OpenAI", "GPT-3.5,GPT-4o"); + public PlatformAuthorizedVO addAuthorizedPlatform(PlatformDTO platformDTO) { + PlatformPO platformPO = platformDao.findByPlatformId(platformDTO.getPlatformId()); + if (platformPO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + Map credentialSet = getStringMap(platformDTO, platformPO); + List models = List.of(platformPO.getSupportModels().split(",")); + if (models.isEmpty()) { + throw new ApiException(ApiExceptionEnum.MODEL_NOT_SUPPORTED); + } + PlatformAuthorizedDTO platformAuthorizedDTO = + new PlatformAuthorizedDTO(platformPO.getName(), credentialSet, models.get(0)); + + if (!testAuthorization(platformAuthorizedDTO)) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + + PlatformAuthorizedPO platformAuthorizedPO = new PlatformAuthorizedPO(); + platformAuthorizedPO.setCredentials(credentialSet); + platformAuthorizedPO.setPlatformId(platformPO.getId()); + + platformAuthorizedDao.saveWithCredentials(platformAuthorizedPO); + PlatformAuthorizedVO platformAuthorizedVO = + PlatformAuthorizedConverter.INSTANCE.fromPO2VO(platformAuthorizedPO, platformPO); + platformAuthorizedVO.setSupportModels(platformPO.getSupportModels()); + platformAuthorizedVO.setPlatformName(platformPO.getName()); + return platformAuthorizedVO; + } + + private static @NotNull Map getStringMap(PlatformDTO platformDTO, PlatformPO platformPO) { + if (platformPO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + Map credentialNeed = platformPO.getCredential(); + Map credentialGet = platformDTO.getAuthCredentials(); + Map credentialSet = new HashMap<>(); + for (String key : credentialNeed.keySet()) { + if (!credentialGet.containsKey(key)) { + throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT); + } + credentialSet.put(key, credentialGet.get(key)); + } + return credentialSet; } @Override public List platformsAuthCredential(Long platformId) { - List platformAuthCredentials = new ArrayList<>(); - platformAuthCredentials.add(new PlatformAuthCredentialVO("api-key", "API Key")); - platformAuthCredentials.add(new PlatformAuthCredentialVO("api-secret", "API Secret")); - return platformAuthCredentials; + PlatformPO platformPO = platformDao.findByPlatformId(platformId); + if (platformPO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + List platformAuthCredentialVOs = new ArrayList<>(); + for (String key : platformPO.getCredential().keySet()) { + PlatformAuthCredentialVO platformAuthCredentialVO = + new PlatformAuthCredentialVO(key, platformPO.getCredential().get(key)); + platformAuthCredentialVOs.add(platformAuthCredentialVO); + } + return platformAuthCredentialVOs; } @Override - public int deleteAuthorizedPlatform(Long platformId) { - Random random = new Random(); - int randomInt = random.nextInt(); - return randomInt % 2; + public boolean deleteAuthorizedPlatform(Long platformId) { + List authorizedPlatformPOs = platformAuthorizedDao.findAll(); + for (PlatformAuthorizedPO authorizedPlatformPO : authorizedPlatformPOs) { + if (authorizedPlatformPO.getId().equals(platformId)) { + platformAuthorizedDao.deleteById(authorizedPlatformPO.getId()); + return true; + } + } + + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); } @Override public ChatThreadVO createChatThreads(Long platformId, String model) { - - return new ChatThreadVO(1L, platformId, model, DateUtils.format(new Date())); + PlatformAuthorizedPO platformAuthorizedPO = platformAuthorizedDao.findByPlatformId(platformId); + if (platformAuthorizedPO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_FOUND); + } + Long userId = SessionUserHolder.getUserId(); + PlatformPO platformPO = platformDao.findByPlatformId(platformAuthorizedPO.getPlatformId()); + List supportModels = List.of(platformPO.getSupportModels().split(",")); + if (!supportModels.contains(model)) { + throw new ApiException(ApiExceptionEnum.MODEL_NOT_SUPPORTED); + } + ChatThreadPO chatThreadPO = new ChatThreadPO(); + chatThreadPO.setUserId(userId); + chatThreadPO.setModel(model); + chatThreadPO.setPlatformId(platformAuthorizedPO.getId()); + chatThreadDao.save(chatThreadPO); + return ChatThreadConverter.INSTANCE.fromPO2VO(chatThreadPO); } @Override - public int deleteChatThreads(Long platformId, Long threadId) { - Random random = new Random(); - int randomInt = random.nextInt(); - return randomInt % 2; + public boolean deleteChatThreads(Long platformId, Long threadId) { + Long userId = SessionUserHolder.getUserId(); + List chatThreadPOS = chatThreadDao.findAllByUserId(userId); + for (ChatThreadPO chatThreadPO : chatThreadPOS) { + if (chatThreadPO.getId().equals(threadId) + && chatThreadPO.getPlatformId().equals(platformId)) { + chatThreadDao.deleteById(threadId); + return true; + } + } + throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); } @Override public List getAllChatThreads(Long platformId, String model) { + Long userId = SessionUserHolder.getUserId(); + List chatThreadPOS = chatThreadDao.findAllByPlatformAuthorizedIdAndUserId(platformId, userId); List chatThreads = new ArrayList<>(); - if (model.equals("GPT-3.5")) { - ChatThreadVO chatThreadVO = new ChatThreadVO(1L, platformId, "GPT-3.5", DateUtils.format(new Date())); - chatThreads.add(chatThreadVO); - ChatThreadVO chatThreadVO2 = new ChatThreadVO(3L, platformId, "GPT-3.5", DateUtils.format(new Date())); - chatThreads.add(chatThreadVO2); - } - if (model.equals("GPT-4o")) { - ChatThreadVO chatThreadVO = new ChatThreadVO(2L, platformId, "GPT-4o", DateUtils.format(new Date())); - chatThreads.add(chatThreadVO); + for (ChatThreadPO chatThreadPO : chatThreadPOS) { + ChatThreadVO chatThreadVO = ChatThreadConverter.INSTANCE.fromPO2VO(chatThreadPO); + if (chatThreadVO.getModel().equals(model)) { + chatThreads.add(chatThreadVO); + } } return chatThreads; } @Override public SseEmitter talk(Long platformId, Long threadId, String message) { - String fullMessage = "Don't ask me" + message; - fullMessage += - """ - I won't tell you Bigtop Manager provides a modern, low-threshold web application to simplify \ - the deployment and management of components for Bigtop, similar to Apache Ambari and Cloudera \ - Manager. - And Bigtop Manager provides a modern, low-threshold web application to simplify \ - the deployment and management of components for Bigtop, similar to Apache Ambari and Cloudera \ - Manager. - """; - - SseEmitter emitter = new SseEmitter(); - Random random = new Random(); + ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); + Long userId = SessionUserHolder.getUserId(); + if (chatThreadPO == null || !Objects.equals(userId, chatThreadPO.getUserId())) { + throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); + } + PlatformAuthorizedPO platformAuthorizedPO = platformAuthorizedDao.findByPlatformId(platformId); + if (platformAuthorizedPO == null) { + throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_AUTHORIZED); + } - try { - StringBuilder remainingMessage = new StringBuilder(fullMessage); + PlatformPO platformPO = platformDao.findById(platformAuthorizedPO.getPlatformId()); + PlatformAuthorizedDTO platformAuthorizedDTO = new PlatformAuthorizedDTO( + platformPO.getName(), platformAuthorizedPO.getCredentials(), chatThreadPO.getModel()); + AIAssistant aiAssistant = buildAIAssistant(platformAuthorizedDTO, chatThreadPO.getId()); + Flux stringFlux = aiAssistant.streamAsk(message); - while (!remainingMessage.isEmpty()) { - int charsToSend = random.nextInt(21); - // 2% probability of simulated transmission failure - if (random.nextInt(50) == 2) { + SseEmitter emitter = new SseEmitter(); + stringFlux.subscribe( + s -> { try { - emitter.send(SseEmitter.event().name("error").data("broken pipe")); - } catch (IOException e) { + emitter.send(s); + } catch (Exception e) { emitter.completeWithError(e); } - emitter.complete(); - return emitter; - } - charsToSend = Math.min(charsToSend, remainingMessage.length()); - - String part = remainingMessage.substring(0, charsToSend); - remainingMessage.delete(0, charsToSend); - - emitter.send(SseEmitter.event().name("message").data(part)); - - int delay = random.nextInt(101); - Thread.sleep(delay); - } - } catch (IOException | InterruptedException e) { - emitter.completeWithError(e); - throw new RuntimeException(e); - } + }, + Throwable::printStackTrace, + emitter::complete); - emitter.complete(); + emitter.onTimeout(emitter::complete); return emitter; } @Override public List history(Long platformId, Long threadId) { List chatMessages = new ArrayList<>(); - Random random = new Random(); - int numberOfMessages = random.nextInt(11); - boolean isUser = true; - - for (int i = 0; i < numberOfMessages; i++) { - String sender = isUser ? "user" : "AI"; - String messageText = isUser ? "hello" : "hello, I'm GPT"; - messageText += i; - - ChatMessageVO message = new ChatMessageVO(sender, messageText, DateUtils.format(new Date())); - chatMessages.add(message); - - isUser = !isUser; + ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId); + if (chatThreadPO == null) { + throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND); + } + Long userId = SessionUserHolder.getUserId(); + if (!chatThreadPO.getUserId().equals(userId)) { + throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED); + } + List chatMessagePOs = chatMessageDao.findAllByThreadId(threadId); + for (ChatMessagePO chatMessagePO : chatMessagePOs) { + ChatMessageVO chatMessageVO = ChatMessageConverter.INSTANCE.fromPO2VO(chatMessagePO); + MessageSender sender = chatMessageVO.getSender(); + if (sender == null) { + continue; + } + if (sender.equals(MessageSender.USER) || sender.equals(MessageSender.AI)) { + chatMessages.add(chatMessageVO); + } } return chatMessages; } diff --git a/bigtop-manager-server/src/main/resources/ddl/MySQL-DDL-CREATE.sql b/bigtop-manager-server/src/main/resources/ddl/MySQL-DDL-CREATE.sql index 94a4ab598..23d384650 100644 --- a/bigtop-manager-server/src/main/resources/ddl/MySQL-DDL-CREATE.sql +++ b/bigtop-manager-server/src/main/resources/ddl/MySQL-DDL-CREATE.sql @@ -310,6 +310,66 @@ CREATE TABLE `stage` KEY `idx_job_id` (`job_id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +CREATE TABLE `llm_platform` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `name` VARCHAR(255) NOT NULL, + `credential` JSON DEFAULT NULL, + `support_models` VARCHAR(255) DEFAULT NULL, + `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP, + `update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `create_by` BIGINT DEFAULT NULL, + `update_by` BIGINT DEFAULT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE `llm_platform_authorized` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `platform_id` BIGINT(20) UNSIGNED NOT NULL, + `credentials` JSON NOT NULL, + `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP, + `update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `create_by` BIGINT DEFAULT NULL, + `update_by` BIGINT DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `idx_platform_id` (`platform_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE `llm_chat_thread` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `platform_id` BIGINT(20) UNSIGNED NOT NULL, + `user_id` BIGINT(20) UNSIGNED NOT NULL, + `model` VARCHAR(255) NOT NULL, + `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP, + `update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `create_by` BIGINT DEFAULT NULL, + `update_by` BIGINT DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `idx_platform_id` (`platform_id`), + KEY `idx_user_id` (`user_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE `llm_chat_message` +( + `id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + `thread_id` BIGINT(20) UNSIGNED NOT NULL, + `user_id` BIGINT(20) UNSIGNED NOT NULL, + `message` TEXT NOT NULL, + `sender` VARCHAR(50) NOT NULL, + `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP, + `update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + KEY `idx_thread_id` (`thread_id`), + KEY `idx_user_id` (`user_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + -- Adding default admin user INSERT INTO bigtop_manager.user (id, create_time, update_time, nickname, password, status, username) -VALUES (1, now(), now(), 'Administrator', '21232f297a57a5a743894a0e4a801fc3', true, 'admin'); \ No newline at end of file +VALUES (1, now(), now(), 'Administrator', '21232f297a57a5a743894a0e4a801fc3', true, 'admin'); + +-- Adding default ai chat platform +INSERT INTO bigtop_manager.llm_platform (id,credential,NAME,support_models) +VALUES +(1,'{"apiKey": "API Key"}','OpenAI','gpt-3.5-turbo,gpt-4,gpt-4o,gpt-3.5-turbo-16k,gpt-4-turbo-preview,gpt-4-32k,gpt-4o-mini'); \ No newline at end of file diff --git a/bigtop-manager-server/src/main/resources/i18n/messages_en_US.properties b/bigtop-manager-server/src/main/resources/i18n/messages_en_US.properties index 5ae40e775..051df6d75 100644 --- a/bigtop-manager-server/src/main/resources/i18n/messages_en_US.properties +++ b/bigtop-manager-server/src/main/resources/i18n/messages_en_US.properties @@ -50,3 +50,10 @@ config.not.found=Config not exist command.not.found=Command not found for level [{0}] command.not.supported=Command [{0}] not supported for level [{1}] + +platform.not.found=platform not found +platform.not.authorized=platform not authorized +permission.denied=permission denied +credit.incorrect=credit incorrect +model.not.supported=model not supported +chat.thread.not.found=chat thread not found diff --git a/bigtop-manager-server/src/main/resources/i18n/messages_zh_CN.properties b/bigtop-manager-server/src/main/resources/i18n/messages_zh_CN.properties index a45ff7162..e4c647f6d 100644 --- a/bigtop-manager-server/src/main/resources/i18n/messages_zh_CN.properties +++ b/bigtop-manager-server/src/main/resources/i18n/messages_zh_CN.properties @@ -50,3 +50,10 @@ config.not.found=配置不存在 command.not.found=不存在 [{0}] 级别的命令 command.not.supported=[{0}] 命令在 [{1}] 级别下不支持 + +platform.not.found=平台不存在 +platform.not.authorized=平台未配置 +permission.denied=权限被拒绝 +credit.incorrect=凭证不正确 +model.not.supported=模型不支持 +chat.thread.not.found=线程不存在 diff --git a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/AIChatControllerTest.java b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/AIChatControllerTest.java index e271b6dac..5c434c444 100644 --- a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/AIChatControllerTest.java +++ b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/AIChatControllerTest.java @@ -18,6 +18,12 @@ */ package org.apache.bigtop.manager.server.controller; +import org.apache.bigtop.manager.server.model.dto.PlatformDTO; +import org.apache.bigtop.manager.server.model.req.PlatformReq; +import org.apache.bigtop.manager.server.model.vo.ChatMessageVO; +import org.apache.bigtop.manager.server.model.vo.ChatThreadVO; +import org.apache.bigtop.manager.server.model.vo.PlatformAuthCredentialVO; +import org.apache.bigtop.manager.server.model.vo.PlatformAuthorizedVO; import org.apache.bigtop.manager.server.model.vo.PlatformVO; import org.apache.bigtop.manager.server.service.AIChatService; import org.apache.bigtop.manager.server.utils.MessageSourceUtils; @@ -32,12 +38,15 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.util.ArrayList; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -64,14 +73,128 @@ void tearDown() { @Test void getAllPlatforms() { - List platforms = new ArrayList<>(); when(chatService.platforms()).thenReturn(platforms); ResponseEntity> response = chatController.platforms(); + assertTrue(response.isSuccess()); assertEquals(platforms, response.getData()); } - // TODO + @Test + void getAuthorizedPlatforms() { + List authorizedPlatforms = new ArrayList<>(); + when(chatService.authorizedPlatforms()).thenReturn(authorizedPlatforms); + + ResponseEntity> response = chatController.authorizedPlatforms(); + + assertTrue(response.isSuccess()); + assertEquals(authorizedPlatforms, response.getData()); + } + + @Test + void platformsAuthCredential() { + Long platformId = 1L; + List credentials = new ArrayList<>(); + when(chatService.platformsAuthCredential(platformId)).thenReturn(credentials); + + ResponseEntity> response = chatController.platformsAuthCredential(platformId); + + assertTrue(response.isSuccess()); + assertEquals(credentials, response.getData()); + } + + @Test + void addAuthorizedPlatform() { + PlatformReq platformReq = new PlatformReq(); + PlatformAuthorizedVO authorizedVO = new PlatformAuthorizedVO(); + + when(chatService.addAuthorizedPlatform(any(PlatformDTO.class))).thenReturn(authorizedVO); + + ResponseEntity response = chatController.addAuthorizedPlatform(platformReq); + + assertTrue(response.isSuccess()); + assertEquals(authorizedVO, response.getData()); + } + + @Test + void deleteAuthorizedPlatform() { + Long platformId = 1L; + when(chatService.deleteAuthorizedPlatform(platformId)).thenReturn(true); + + ResponseEntity response = chatController.deleteAuthorizedPlatform(platformId); + + assertTrue(response.isSuccess()); + assertEquals(true, response.getData()); + } + + @Test + void createChatThreads() { + Long platformId = 1L; + String model = "model1"; + ChatThreadVO chatThread = new ChatThreadVO(); + + when(chatService.createChatThreads(eq(platformId), eq(model))).thenReturn(chatThread); + + ResponseEntity response = chatController.createChatThreads(platformId, model); + + assertTrue(response.isSuccess()); + assertEquals(chatThread, response.getData()); + } + + @Test + void deleteChatThreads() { + Long platformId = 1L; + Long threadId = 1L; + + when(chatService.deleteChatThreads(platformId, threadId)).thenReturn(true); + + ResponseEntity response = chatController.deleteChatThreads(platformId, threadId); + + assertTrue(response.isSuccess()); + assertEquals(true, response.getData()); + } + + @Test + void getAllChatThreads() { + Long platformId = 1L; + String model = "model1"; + List chatThreads = new ArrayList<>(); + + when(chatService.getAllChatThreads(eq(platformId), eq(model))).thenReturn(chatThreads); + + ResponseEntity> response = chatController.getAllChatThreads(platformId, model); + + assertTrue(response.isSuccess()); + assertEquals(chatThreads, response.getData()); + } + + @Test + void talk() { + Long platformId = 1L; + Long threadId = 1L; + String message = "Hello"; + + SseEmitter emitter = new SseEmitter(); + when(chatService.talk(eq(platformId), eq(threadId), eq(message))).thenReturn(emitter); + + SseEmitter result = chatController.talk(platformId, threadId, message); + + assertEquals(emitter, result); + } + + @Test + void history() { + Long platformId = 1L; + Long threadId = 1L; + List history = new ArrayList<>(); + + when(chatService.history(platformId, threadId)).thenReturn(history); + + ResponseEntity> response = chatController.history(platformId, threadId); + + assertTrue(response.isSuccess()); + assertEquals(history, response.getData()); + } }