Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5e965f7
tmp commit
lhpqaq Aug 24, 2024
05f43dc
finish platform api
lhpqaq Aug 26, 2024
547bc57
tmp talk
lhpqaq Aug 26, 2024
fd9b8ae
remain talk and history
lhpqaq Aug 26, 2024
4c65a1d
mvn spotless:apply
lhpqaq Aug 26, 2024
fc4ffdc
Merge branch 'ai/chat' into dev
lhpqaq Aug 26, 2024
5686fbb
some details left
lhpqaq Aug 26, 2024
cab275d
remain test
lhpqaq Aug 26, 2024
7ebdff9
spotless:check
lhpqaq Aug 26, 2024
afb0d0d
add language
lhpqaq Aug 26, 2024
64bec89
spotless:apply
lhpqaq Aug 26, 2024
24fd4dd
finish
lhpqaq Aug 27, 2024
2308681
last spotless:apply
lhpqaq Aug 27, 2024
81dc33d
add license
lhpqaq Aug 27, 2024
95ad7cf
fix some chore and add ddl
lhpqaq Aug 27, 2024
37fce01
remove a comment and log
lhpqaq Aug 27, 2024
6115c1b
modify openai url
lhpqaq Aug 27, 2024
45eb6cc
add ai assistant func
lhpqaq Aug 28, 2024
8d5caea
add bigmodel moudle
lhpqaq Aug 28, 2024
1486aab
mvn spotless:apply
lhpqaq Aug 28, 2024
c21e81d
add bigmodel
lhpqaq Aug 28, 2024
e8bfd1e
remove a variable
lhpqaq Aug 28, 2024
b4da0dd
add a constructor function
lhpqaq Aug 28, 2024
f7f06ed
add enum MessageSender
lhpqaq Aug 28, 2024
2e158f9
fix var name
lhpqaq Aug 28, 2024
e703386
add qianfan
lhpqaq Aug 29, 2024
4696f7a
remove url
lhpqaq Aug 29, 2024
f2ac4d5
revert app.yml
lhpqaq Aug 29, 2024
0ecabe4
remove qianfan and zhipu
lhpqaq Aug 29, 2024
49b1466
apply diff from @kevinw66
lhpqaq Aug 29, 2024
b3ce436
fix some bug
lhpqaq Aug 29, 2024
531634d
rename dao
lhpqaq Aug 29, 2024
aa74b09
tmp commit
lhpqaq Aug 29, 2024
b55b95b
finish switch to mybatis
lhpqaq Aug 29, 2024
1aac351
remove JsonToMapConverter
lhpqaq Aug 29, 2024
71cd3de
revert pom.xml
lhpqaq Aug 29, 2024
cba1b1f
fix some chores
lhpqaq Aug 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-dao</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@
import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.ToolBox;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.apache.commons.lang3.NotImplementedException;

import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
Expand All @@ -35,13 +39,19 @@

public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private SystemPromptProvider systemPromptProvider = new LocSystemPromptProvider();
private ChatMemoryStore chatMemoryStore = new InMemoryChatMemoryStore();
private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStore chatMemoryStore;

public GeneralAssistantFactory() {}
public GeneralAssistantFactory() {
this(new LocSystemPromptProvider(), new InMemoryChatMemoryStore());
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider) {
this.systemPromptProvider = systemPromptProvider;
this(systemPromptProvider, new InMemoryChatMemoryStore());
}

public GeneralAssistantFactory(ChatMemoryStore chatMemoryStore) {
this(new LocSystemPromptProvider(), chatMemoryStore);
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMemoryStore chatMemoryStore) {
Expand All @@ -51,29 +61,33 @@ public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMe

@Override
public AIAssistant createWithPrompt(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, Object promptId) {
AIAssistant aiAssistant = create(platformType, assistantConfig, id);
SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(promptId);
aiAssistant.setSystemPrompt(systemPrompt);
return aiAssistant;
}

@Override
public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) {
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
SystemPrompt systemPrompts) {
AIAssistant aiAssistant;
if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) {
AIAssistant aiAssistant = OpenAIAssistant.builder()
aiAssistant = OpenAIAssistant.builder()
.id(id)
.memoryStore(chatMemoryStore)
.withConfigProvider(assistantConfig)
.build();
aiAssistant.setSystemPrompt(systemPromptProvider.getSystemPrompt());
return aiAssistant;
} else {
throw new PlatformNotFoundException(platformType.getValue());
}
return null;

SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(systemPrompts);
aiAssistant.setSystemPrompt(systemPrompt);
return aiAssistant;
}

@Override
public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) {
return createWithPrompt(platformType, assistantConfig, id, SystemPrompt.DEFAULT_PROMPT);
}

@Override
public ToolBox createToolBox(PlatformType platformType) {
return null;
throw new NotImplementedException("ToolBox is not implemented for GeneralAssistantFactory");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,83 @@
import java.util.Map;

public class AIAssistantConfig implements AIAssistantConfigProvider {
private final Map<String, String> configMap;

private AIAssistantConfig(Map<String, String> configMap) {
this.configMap = configMap;
/**
* Model name for platform that we want to use
*/
private final String model;

/**
* Credentials for different platforms
*/
private final Map<String, String> credentials;

/**
* Platform extra configs are put here
*/
private final Map<String, String> configs;

private AIAssistantConfig(String model, Map<String, String> credentials, Map<String, String> configMap) {
this.model = model;
this.credentials = credentials;
this.configs = configMap;
}

public static Builder builder() {
return new Builder();
}

public static Builder withDefault(String baseUrl, String apiKey) {
Builder builder = new Builder();
return builder.set("baseUrl", baseUrl).set("apiKey", apiKey);
@Override
public String getModel() {
return model;
}

@Override
public Map<String, String> configs() {
public Map<String, String> getCredentials() {
return credentials;
}

return configMap;
@Override
public Map<String, String> getConfigs() {
return configs;
}

public static class Builder {
private final Map<String, String> configs;
private String model;

private final Map<String, String> credentials = new HashMap<>();

private final Map<String, String> configs = new HashMap<>();

public Builder() {
configs = new HashMap<>();
configs.put("memoryLen", "30");
public Builder() {}

public Builder setModel(String model) {
this.model = model;
return this;
}

public Builder set(String key, String value) {
public Builder addCredential(String key, String value) {
credentials.put(key, value);
return this;
}

public Builder addCredentials(Map<String, String> credentialMap) {
credentials.putAll(credentialMap);
return this;
}

public Builder addConfig(String key, String value) {
configs.put(key, value);
return this;
}

public Builder addConfigs(Map<String, String> configMap) {
configs.putAll(configMap);
return this;
}

public AIAssistantConfig build() {
return new AIAssistantConfig(configs);
return new AIAssistantConfig(model, credentials, configs);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;

import org.springframework.util.ResourceUtils;
Expand All @@ -29,41 +30,37 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Objects;

@Slf4j
public class LocSystemPromptProvider implements SystemPromptProvider {

public static final String DEFAULT = "default";
private static final String SYSTEM_PROMPT_PATH = "src/main/resources/";
private static final String DEFAULT_NAME = "big-data-professor.st";

@Override
public SystemMessage getSystemPrompt(Object id) {
if (Objects.equals(id.toString(), DEFAULT)) {
return getSystemPrompt();
} else {
return loadPromptFromFile(id.toString());
public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
if (systemPrompt == SystemPrompt.DEFAULT_PROMPT) {
systemPrompt = SystemPrompt.BIGDATA_PROFESSOR;
}

return loadPromptFromFile(systemPrompt.getValue());
}

@Override
public SystemMessage getSystemPrompt() {
return loadPromptFromFile(DEFAULT_NAME);
return getSystemPrompt(SystemPrompt.DEFAULT_PROMPT);
}

private SystemMessage loadPromptFromFile(String fileName) {
final String filePath = SYSTEM_PROMPT_PATH + fileName;
final String filePath = SYSTEM_PROMPT_PATH + fileName + ".st";
try {
File file = ResourceUtils.getFile(filePath);
String text = Files.readString(file.toPath(), StandardCharsets.UTF_8);
return SystemMessage.from(text);
} catch (IOException e) {
//
log.error(
"Exception occurred while loading SystemPrompt from local. Here is some information:{}",
e.getMessage());
return SystemMessage.from("");
return SystemMessage.from("You are a helpful assistant.");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class PersistentChatMemoryStore implements ChatMemoryStore {

private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageSender.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
}

private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
chatMessagePO.setSender(MessageSender.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setSender(MessageSender.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
} else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) {
chatMessagePO.setSender(MessageSender.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) chatMessage;
chatMessagePO.setMessage(systemMessage.text());
} else {
chatMessagePO.setSender(chatMessage.type().toString());
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadId);
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(chatThreadId);
return chatMessagePO;
}

@Override
public List<ChatMessage> getMessages(Object threadId) {
List<ChatMessagePO> chatMessages = chatMessageDao.findAllByThreadId((Long) threadId);
if (chatMessages.isEmpty()) {
return new ArrayList<>();
} else {
return chatMessages.stream().map(this::convertToChatMessage).collect(Collectors.toList());
}
}

@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
chatMessageDao.save(chatMessagePO);
}

@Override
public void deleteMessages(Object threadId) {
chatMessageDao.deleteByThreadId((Long) threadId);
}
}
Loading