Skip to content

Commit a8bdc00

Browse files
Feat/cohere agent implementation (#4703)
* implement cohere agent support * run yarn lint * moderize Cohere add supported langchain method redo streaming since it was not working looping of agent calls was not functioning * change default model to real model tag add case statement for model tag * remove debug * update default * only whitelist known labels --------- Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
1 parent 62b45a7 commit a8bdc00

10 files changed

Lines changed: 340 additions & 6 deletions

File tree

frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ const ENABLED_PROVIDERS = [
3636
"foundry",
3737
"zai",
3838
"giteeai",
39+
"cohere",
3940
// TODO: More agent support.
40-
// "cohere", // Has tool calling and will need to build explicit support
4141
// "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested.
4242
];
4343
const WARN_PERFORMANCE = [

server/endpoints/utils.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ function getModelTag() {
151151
case "giteeai":
152152
model = process.env.GITEE_AI_MODEL_PREF;
153153
break;
154+
case "cohere":
155+
model = process.env.COHERE_MODEL_PREF;
156+
break;
154157
default:
155158
model = "--";
156159
break;

server/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"@lancedb/lancedb": "0.15.0",
2727
"@langchain/anthropic": "0.1.16",
2828
"@langchain/aws": "^0.0.5",
29+
"@langchain/cohere": "0.0.11",
2930
"@langchain/community": "0.0.53",
3031
"@langchain/core": "0.1.61",
3132
"@langchain/openai": "0.0.28",

server/utils/agents/aibitat/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,8 @@ ${this.getHistory({ to: route.to })
990990
return new Providers.FoundryProvider({ model: config.model });
991991
case "giteeai":
992992
return new Providers.GiteeAIProvider({ model: config.model });
993+
case "cohere":
994+
return new Providers.CohereProvider({ model: config.model });
993995
default:
994996
throw new Error(
995997
`Unknown provider: ${config.provider}. Please use a valid provider.`

server/utils/agents/aibitat/providers/ai-provider.js

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
const { v4 } = require("uuid");
1414
const { ChatOpenAI } = require("@langchain/openai");
1515
const { ChatAnthropic } = require("@langchain/anthropic");
16+
const { ChatCohere } = require("@langchain/cohere");
1617
const { ChatOllama } = require("@langchain/community/chat_models/ollama");
1718
const { toValidNumber, safeJsonParse } = require("../../../http");
1819
const { getLLMProviderClass } = require("../../../helpers");
@@ -239,6 +240,11 @@ class Provider {
239240
apiKey: process.env.GITEE_AI_API_KEY ?? null,
240241
...config,
241242
});
243+
case "cohere":
244+
return new ChatCohere({
245+
apiKey: process.env.COHERE_API_KEY ?? null,
246+
...config,
247+
});
242248
// OSS Model Runners
243249
// case "anythingllm_ollama":
244250
// return new ChatOllama({
@@ -307,7 +313,6 @@ class Provider {
307313
...config,
308314
});
309315
}
310-
311316
default:
312317
throw new Error(`Unsupported provider ${provider} for this task.`);
313318
}
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
const { CohereClient } = require("cohere-ai");
2+
const Provider = require("./ai-provider");
3+
const InheritMultiple = require("./helpers/classes");
4+
const UnTooled = require("./helpers/untooled");
5+
const { v4 } = require("uuid");
6+
const { safeJsonParse } = require("../../../http");
7+
8+
class CohereProvider extends InheritMultiple([Provider, UnTooled]) {
9+
model;
10+
11+
constructor(config = {}) {
12+
const { model = process.env.COHERE_MODEL_PREF || "command-r-08-2024" } =
13+
config;
14+
super();
15+
const client = new CohereClient({
16+
token: process.env.COHERE_API_KEY,
17+
});
18+
this._client = client;
19+
this.model = model;
20+
this.verbose = true;
21+
}
22+
23+
get client() {
24+
return this._client;
25+
}
26+
27+
get supportsAgentStreaming() {
28+
return true;
29+
}
30+
31+
#convertChatHistoryCohere(chatHistory = []) {
32+
let cohereHistory = [];
33+
chatHistory.forEach((message) => {
34+
switch (message.role) {
35+
case "SYSTEM":
36+
case "system":
37+
cohereHistory.push({ role: "SYSTEM", message: message.content });
38+
break;
39+
case "USER":
40+
case "user":
41+
cohereHistory.push({ role: "USER", message: message.content });
42+
break;
43+
case "CHATBOT":
44+
case "assistant":
45+
cohereHistory.push({ role: "CHATBOT", message: message.content });
46+
break;
47+
}
48+
});
49+
50+
return cohereHistory;
51+
}
52+
53+
async #handleFunctionCallStream({ messages = [] }) {
54+
const userPrompt = messages[messages.length - 1]?.content || "";
55+
const history = messages.slice(0, -1);
56+
return await this.client.chatStream({
57+
model: this.model,
58+
chatHistory: this.#convertChatHistoryCohere(history),
59+
message: userPrompt,
60+
});
61+
}
62+
63+
async stream(messages, functions = [], eventHandler = null) {
64+
return await UnTooled.prototype.stream.call(
65+
this,
66+
messages,
67+
functions,
68+
this.#handleFunctionCallStream.bind(this),
69+
eventHandler
70+
);
71+
}
72+
73+
async streamingFunctionCall(
74+
messages,
75+
functions,
76+
chatCb = null,
77+
eventHandler = null
78+
) {
79+
const history = [...messages].filter((msg) =>
80+
["user", "assistant"].includes(msg.role)
81+
);
82+
if (history[history.length - 1]?.role !== "user") return null;
83+
84+
const msgUUID = v4();
85+
let textResponse = "";
86+
const historyMessages = this.buildToolCallMessages(history, functions);
87+
const stream = await chatCb({ messages: historyMessages });
88+
89+
eventHandler?.("reportStreamEvent", {
90+
type: "statusResponse",
91+
uuid: v4(),
92+
content: "Agent is thinking...",
93+
});
94+
95+
for await (const event of stream) {
96+
if (event.eventType !== "text-generation") continue;
97+
textResponse += event.text;
98+
eventHandler?.("reportStreamEvent", {
99+
type: "statusResponse",
100+
uuid: msgUUID,
101+
content: event.text,
102+
});
103+
}
104+
105+
const call = safeJsonParse(textResponse, null);
106+
if (call === null)
107+
return { toolCall: null, text: textResponse, uuid: msgUUID };
108+
109+
const { valid, reason } = this.validFuncCall(call, functions);
110+
if (!valid) {
111+
this.providerLog(`Invalid function tool call: ${reason}.`);
112+
eventHandler?.("reportStreamEvent", {
113+
type: "removeStatusResponse",
114+
uuid: msgUUID,
115+
content:
116+
"The model attempted to make an invalid function call - it was ignored.",
117+
});
118+
return { toolCall: null, text: null, uuid: msgUUID };
119+
}
120+
121+
const { isDuplicate, reason: duplicateReason } =
122+
this.deduplicator.isDuplicate(call.name, call.arguments);
123+
if (isDuplicate) {
124+
this.providerLog(
125+
`Cannot call ${call.name} again because ${duplicateReason}.`
126+
);
127+
eventHandler?.("reportStreamEvent", {
128+
type: "removeStatusResponse",
129+
uuid: msgUUID,
130+
content:
131+
"The model tried to call a function with the same arguments as a previous call - it was ignored.",
132+
});
133+
return { toolCall: null, text: null, uuid: msgUUID };
134+
}
135+
136+
eventHandler?.("reportStreamEvent", {
137+
uuid: `${msgUUID}:tool_call_invocation`,
138+
type: "toolCallInvocation",
139+
content: `Parsed Tool Call: ${call.name}(${JSON.stringify(call.arguments)})`,
140+
});
141+
return { toolCall: call, text: null, uuid: msgUUID };
142+
}
143+
144+
/**
145+
* Stream a chat completion from the LLM with tool calling
146+
* Override the inherited `stream` method since Cohere uses a different API format.
147+
*
148+
* @param {any[]} messages - The messages to send to the LLM.
149+
* @param {any[]} functions - The functions to use in the LLM.
150+
* @param {function} eventHandler - The event handler to use to report stream events.
151+
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
152+
*/
153+
async stream(messages, functions = [], eventHandler = null) {
154+
this.providerLog(
155+
"CohereProvider.stream - will process this chat completion."
156+
);
157+
try {
158+
let completion = { content: "" };
159+
if (functions.length > 0) {
160+
const {
161+
toolCall,
162+
text,
163+
uuid: msgUUID,
164+
} = await this.streamingFunctionCall(
165+
messages,
166+
functions,
167+
this.#handleFunctionCallStream.bind(this),
168+
eventHandler
169+
);
170+
171+
if (toolCall !== null) {
172+
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
173+
this.deduplicator.trackRun(toolCall.name, toolCall.arguments, {
174+
cooldown: this.isMCPTool(toolCall, functions),
175+
});
176+
return {
177+
result: null,
178+
functionCall: {
179+
name: toolCall.name,
180+
arguments: toolCall.arguments,
181+
},
182+
cost: 0,
183+
};
184+
}
185+
186+
if (text) {
187+
this.providerLog(
188+
`No tool call found in the response - will send as a full text response.`
189+
);
190+
completion.content = text;
191+
eventHandler?.("reportStreamEvent", {
192+
type: "removeStatusResponse",
193+
uuid: msgUUID,
194+
content: "No tool call found in the response",
195+
});
196+
eventHandler?.("reportStreamEvent", {
197+
type: "statusResponse",
198+
uuid: v4(),
199+
content: "Done thinking.",
200+
});
201+
eventHandler?.("reportStreamEvent", {
202+
type: "fullTextResponse",
203+
uuid: v4(),
204+
content: text,
205+
});
206+
}
207+
}
208+
209+
if (!completion?.content) {
210+
eventHandler?.("reportStreamEvent", {
211+
type: "statusResponse",
212+
uuid: v4(),
213+
content: "Done thinking.",
214+
});
215+
216+
this.providerLog(
217+
"Will assume chat completion without tool call inputs."
218+
);
219+
const msgUUID = v4();
220+
completion = { content: "" };
221+
const stream = await this.#handleFunctionCallStream({
222+
messages: this.cleanMsgs(messages),
223+
});
224+
225+
for await (const chunk of stream) {
226+
if (chunk.eventType !== "text-generation") continue;
227+
completion.content += chunk.text;
228+
eventHandler?.("reportStreamEvent", {
229+
type: "textResponseChunk",
230+
uuid: msgUUID,
231+
content: chunk.text,
232+
});
233+
}
234+
}
235+
236+
this.deduplicator.reset("runs");
237+
return {
238+
textResponse: completion.content,
239+
cost: 0,
240+
};
241+
} catch (error) {
242+
throw error;
243+
}
244+
}
245+
246+
getCost(_usage) {
247+
return 0;
248+
}
249+
}
250+
251+
module.exports = CohereProvider;

server/utils/agents/aibitat/providers/index.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const MoonshotAiProvider = require("./moonshotAi.js");
2828
const CometApiProvider = require("./cometapi.js");
2929
const FoundryProvider = require("./foundry.js");
3030
const GiteeAIProvider = require("./giteeai.js");
31+
const CohereProvider = require("./cohere.js");
3132

3233
module.exports = {
3334
OpenAIProvider,
@@ -60,4 +61,5 @@ module.exports = {
6061
MoonshotAiProvider,
6162
FoundryProvider,
6263
GiteeAIProvider,
64+
CohereProvider,
6365
};

server/utils/agents/index.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ class AgentHandler {
219219
case "giteeai":
220220
if (!process.env.GITEE_AI_API_KEY)
221221
throw new Error("GiteeAI API Key must be provided to use agents.");
222+
break;
223+
case "cohere":
224+
if (!process.env.COHERE_API_KEY)
225+
throw new Error("Cohere API key must be provided to use agents.");
226+
break;
222227
default:
223228
throw new Error(
224229
"No workspace agent provider set. Please set your agent provider in the workspace's settings"
@@ -297,6 +302,8 @@ class AgentHandler {
297302
return process.env.FOUNDRY_MODEL_PREF ?? null;
298303
case "giteeai":
299304
return process.env.GITEE_AI_MODEL_PREF ?? null;
305+
case "cohere":
306+
return process.env.COHERE_MODEL_PREF ?? "command-r-08-2024";
300307
default:
301308
return null;
302309
}

server/utils/helpers/customModels.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ async function getCohereModels(_apiKey = null, type = "chat") {
814814
.then((results) => results.models)
815815
.then((models) =>
816816
models.map((model) => ({
817-
id: model.id,
817+
id: model.name,
818818
name: model.name,
819819
}))
820820
)

0 commit comments

Comments
 (0)