Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-dashscope</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-dao</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,44 @@
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.provider.PersistentStoreProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.ToolBox;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.MessageStoreProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.apache.commons.lang3.NotImplementedException;

import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import java.util.Objects;

public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStore chatMemoryStore;
private final MessageStoreProvider messageStoreProvider;

public GeneralAssistantFactory() {
this(new LocSystemPromptProvider(), new InMemoryChatMemoryStore());
this(new LocSystemPromptProvider(), new PersistentStoreProvider());
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider) {
this(systemPromptProvider, new InMemoryChatMemoryStore());
this(systemPromptProvider, new PersistentStoreProvider());
}

public GeneralAssistantFactory(ChatMemoryStore chatMemoryStore) {
this(new LocSystemPromptProvider(), chatMemoryStore);
public GeneralAssistantFactory(MessageStoreProvider messageStoreProvider) {
this(new LocSystemPromptProvider(), messageStoreProvider);
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMemoryStore chatMemoryStore) {
public GeneralAssistantFactory(
SystemPromptProvider systemPromptProvider, MessageStoreProvider messageStoreProvider) {
this.systemPromptProvider = systemPromptProvider;
this.chatMemoryStore = chatMemoryStore;
this.messageStoreProvider = messageStoreProvider;
}

@Override
Expand All @@ -69,14 +69,19 @@ public AIAssistant createWithPrompt(
if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) {
aiAssistant = OpenAIAssistant.builder()
.id(id)
.memoryStore(chatMemoryStore)
.memoryStore(messageStoreProvider.getChatMemoryStore())
.withConfigProvider(assistantConfig)
.build();
} else if (Objects.requireNonNull(platformType) == PlatformType.DASH_SCOPE) {
aiAssistant = DashScopeAssistant.builder()
.id(id)
.withConfigProvider(assistantConfig)
.messageRepository(messageStoreProvider.getMessageRepository())
.build();
} else {
throw new PlatformNotFoundException(platformType.getValue());
}

SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(systemPrompts);
String systemPrompt = systemPromptProvider.getSystemMessage(systemPrompts);
aiAssistant.setSystemPrompt(systemPrompt);
String locale = assistantConfig.getLanguage();
if (locale != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ public Builder addConfig(String key, String value) {
}

public Builder addConfigs(Map<String, String> configMap) {
configs.putAll(configMap);
if (configMap != null) {
configs.putAll(configMap);
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import org.springframework.util.ResourceUtils;

import dev.langchain4j.data.message.SystemMessage;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
Expand All @@ -36,7 +35,7 @@
public class LocSystemPromptProvider implements SystemPromptProvider {

@Override
public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
public String getSystemMessage(SystemPrompt systemPrompt) {
if (systemPrompt == SystemPrompt.DEFAULT_PROMPT) {
systemPrompt = SystemPrompt.BIGDATA_PROFESSOR;
}
Expand All @@ -45,8 +44,8 @@ public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
}

@Override
public SystemMessage getSystemPrompt() {
return getSystemPrompt(SystemPrompt.DEFAULT_PROMPT);
public String getSystemMessage() {
return getSystemMessage(SystemPrompt.DEFAULT_PROMPT);
}

private String loadTextFromFile(String fileName) {
Expand All @@ -64,23 +63,23 @@ private String loadTextFromFile(String fileName) {
}
}

private SystemMessage loadPromptFromFile(String fileName) {
private String loadPromptFromFile(String fileName) {
final String filePath = fileName + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return SystemMessage.from("You are a helpful assistant.");
return "You are a helpful assistant.";
} else {
return SystemMessage.from(text);
return text;
}
}

public SystemMessage getLanguagePrompt(String locale) {
public String getLanguagePrompt(String locale) {
final String filePath = SystemPrompt.LANGUAGE_PROMPT.getValue() + '-' + locale + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return SystemMessage.from("Answer in " + locale);
return "Answer in " + locale;
} else {
return SystemMessage.from(text);
return text;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.ai.assistant.store.PersistentMessageRepository;
import org.apache.bigtop.manager.ai.core.provider.MessageStoreProvider;
import org.apache.bigtop.manager.ai.core.repository.MessageRepository;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class PersistentStoreProvider implements MessageStoreProvider {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

public PersistentStoreProvider() {
chatMessageDao = null;
chatThreadDao = null;
}

@Override
public MessageRepository getMessageRepository() {
return new PersistentMessageRepository(chatThreadDao, chatMessageDao);
}

@Override
public ChatMemoryStore getChatMemoryStore() {
if (chatThreadDao == null) {
return new InMemoryChatMemoryStore();
}
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.repository.MessageRepository;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

public class PersistentMessageRepository implements MessageRepository {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

private boolean noPersistent() {
return chatThreadDao == null || chatMessageDao == null;
}

public PersistentMessageRepository(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessagePO getChatMessagePO(String message, Long threadId, MessageSender sender) {
if (noPersistent()) {
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
ChatMessagePO chatMessagePO = new ChatMessagePO();
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(threadId);
chatMessagePO.setSender(sender.getValue());
chatMessagePO.setMessage(message);
return chatMessagePO;
}

@Override
public void saveUserMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.USER);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveAiMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.AI);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveSystemMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.SYSTEM);
chatMessageDao.save(chatMessagePO);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

import org.junit.jupiter.api.Test;

import dev.langchain4j.data.message.SystemMessage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

Expand All @@ -35,12 +33,12 @@ public class SystemPromptProviderTests {

@Test
public void loadSystemPromptByIdTest() {
SystemMessage systemPrompt1 = systemPromptProvider.getSystemPrompt(SystemPrompt.BIGDATA_PROFESSOR);
assertFalse(systemPrompt1.text().isEmpty());
String systemPrompt1 = systemPromptProvider.getSystemMessage(SystemPrompt.BIGDATA_PROFESSOR);
assertFalse(systemPrompt1.isEmpty());

SystemMessage systemPrompt2 = systemPromptProvider.getSystemPrompt();
assertFalse(systemPrompt2.text().isEmpty());
String systemPrompt2 = systemPromptProvider.getSystemMessage();
assertFalse(systemPrompt2.isEmpty());

assertEquals(systemPrompt1.text(), systemPrompt2.text());
assertEquals(systemPrompt1, systemPrompt2);
}
}
Loading