Skip to content

Commit 0f3cf6a

Browse files
improvements
1 parent d37d752 commit 0f3cf6a

File tree

2 files changed

+173
-59
lines changed

2 files changed

+173
-59
lines changed

libs/langchain/src/agents/middlewareAgent/middleware/bigTool.ts

Lines changed: 146 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,26 @@ export type ToolSelectionStrategy =
1414
/**
1515
* Custom tool selector function type
1616
*/
17-
export type CustomToolSelector<Context = Record<string, unknown>> = (
18-
tools: (ClientTool | ServerTool)[],
19-
query: string,
20-
context: Context
21-
) => Promise<(ClientTool | ServerTool)[]> | (ClientTool | ServerTool)[];
17+
const customToolSelector = z
18+
.function()
19+
.args(
20+
z.object({
21+
// all tools
22+
tools: z.array(z.custom<ClientTool | ServerTool>()).describe("Alltools"),
23+
// user query
24+
query: z.string(),
25+
// agent context
26+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
27+
context: z.custom<Record<string, any>>(),
28+
})
29+
)
30+
.returns(
31+
z.union([
32+
z.array(z.custom<ClientTool | ServerTool>()),
33+
z.promise(z.array(z.custom<ClientTool | ServerTool>())),
34+
])
35+
);
36+
export type CustomToolSelector = z.infer<typeof customToolSelector>;
2237

2338
/**
2439
* Keyword matching configuration
@@ -98,13 +113,18 @@ const contextSchema = z.object({
98113
.object({
99114
threshold: z.number().min(0).max(1).default(0.3),
100115
maxTools: z.number().positive().default(10),
116+
embedFunction: z
117+
.function()
118+
.args(z.string())
119+
.returns(z.union([z.array(z.number()), z.promise(z.array(z.number()))]))
120+
.optional(),
101121
})
102122
.optional(),
103123

104124
/**
105125
* Custom tool selector function
106126
*/
107-
customSelector: z.custom<CustomToolSelector>().optional(),
127+
customSelector: customToolSelector.optional(),
108128
});
109129

110130
/**
@@ -208,20 +228,46 @@ function calculateSimpleSimilarity(text1: string, text2: string): number {
208228
return union.size > 0 ? intersection.size / union.size : 0;
209229
}
210230

231+
/**
 * Compute the cosine similarity between two numeric vectors.
 *
 * Vectors of different lengths are compared as if the shorter one were
 * zero-padded: a missing component contributes nothing to the dot product,
 * while every component that is present still counts toward the norm of its
 * own vector.
 *
 * @param a - First vector.
 * @param b - Second vector.
 * @returns The cosine of the angle between `a` and `b`, or `0` when either
 * vector has zero magnitude or the result is not a finite number.
 */
function cosine(a: number[], b: number[]): number {
  let dot = 0;
  let sumSquaresA = 0;
  let sumSquaresB = 0;

  // Accumulate in one pass instead of spreading into Math.hypot(...vector):
  // spreading a very large array as call arguments can exceed the engine's
  // argument limit and throw a RangeError.
  const length = Math.max(a.length, b.length);
  for (let i = 0; i < length; i += 1) {
    const x = a[i] ?? 0;
    const y = b[i] ?? 0;
    dot += x * y;
    sumSquaresA += x * x;
    sumSquaresB += y * y;
  }

  const denominator = Math.sqrt(sumSquaresA) * Math.sqrt(sumSquaresB);
  // A zero (or NaN) denominator means at least one vector has no usable
  // magnitude, so the similarity is defined as 0.
  if (!denominator) {
    return 0;
  }

  const similarity = dot / denominator;
  return Number.isFinite(similarity) ? similarity : 0;
}
240+
211241
/**
212242
* Semantic similarity-based tool selection
213243
*/
214-
function selectToolsBySemantic(
244+
async function selectToolsBySemantic(
215245
tools: (ClientTool | ServerTool)[],
216246
config: SemanticMatchConfig,
217247
query: string
218-
): (ClientTool | ServerTool)[] {
219-
const { threshold = 0.3, maxTools = 10 } = config;
248+
): Promise<(ClientTool | ServerTool)[]> {
249+
const { threshold = 0.3, maxTools = 10, embedFunction } = config;
220250

221251
if (!query.trim()) {
222252
return tools.slice(0, maxTools);
223253
}
224254

255+
if (embedFunction) {
256+
const q = await Promise.resolve(embedFunction(query));
257+
const scored = await Promise.all(
258+
tools.map(async (tool) => {
259+
const t = getToolText(tool);
260+
const e = await Promise.resolve(embedFunction(t));
261+
return { tool, score: cosine(q, e) };
262+
})
263+
);
264+
return scored
265+
.filter(({ score }) => score >= threshold)
266+
.sort((a, b) => b.score - a.score)
267+
.slice(0, maxTools)
268+
.map(({ tool }) => tool);
269+
}
270+
225271
const toolsWithScores = tools.map((tool) => ({
226272
tool,
227273
score: calculateSimpleSimilarity(query, getToolText(tool)),
@@ -275,14 +321,14 @@ function selectToolsBySemantic(
275321
* @example
276322
* Basic usage with keyword matching
277323
* ```typescript
278-
* import { bigToolMiddleware } from "langchain/middleware";
324+
* import { bigTool } from "langchain/middleware";
279325
* import { createAgent } from "langchain";
280326
*
281327
* const agent = createAgent({
282328
* model: "openai:gpt-4",
283329
* tools: [...manyTools], // 1000+ tools
284330
* middleware: [
285-
* bigToolMiddleware({
331+
* bigTool({
286332
* strategy: "keyword",
287333
* maxTools: 20,
288334
* keywordConfig: {
@@ -298,19 +344,54 @@ function selectToolsBySemantic(
298344
* @example
299345
* Semantic similarity strategy
300346
* ```typescript
301-
* const semanticMiddleware = bigToolMiddleware({
347+
* // Basic semantic similarity using simple text-overlap
348+
* const semanticMiddleware = bigTool({
302349
* strategy: "semantic",
303350
* semanticConfig: {
304351
* threshold: 0.4,
305-
* maxTools: 15
306-
* }
352+
* maxTools: 15,
353+
* },
354+
* });
355+
*
356+
* // Semantic similarity with a custom embedding function (e.g. cosine ranking)
357+
* // Note: embedFunction can be async. It should return a numeric vector.
358+
* import { embedText } from "./myEmbeddings";
359+
* const semanticWithEmbeddings = bigTool({
360+
* strategy: "semantic",
361+
* semanticConfig: {
362+
* threshold: 0.35,
363+
* maxTools: 10,
364+
* embedFunction: (text: string) => embedText(text),
365+
* },
366+
* });
367+
*
368+
* // Semantic similarity with a vendor embedding function
369+
* import OpenAI from "openai";
370+
* import { bigTool } from "langchain/middleware";
371+
*
372+
* const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY! });
373+
* const embedFunction = async (text: string): Promise<number[]> => {
374+
* const res = await openai.embeddings.create({
375+
* model: "text-embedding-3-small", // or "text-embedding-3-large"
376+
* input: text,
377+
* });
378+
* return res.data[0].embedding;
379+
* };
380+
*
381+
* const middleware = bigTool({
382+
* strategy: "semantic",
383+
* semanticConfig: {
384+
* threshold: 0.35,
385+
* maxTools: 10,
386+
* embedFunction,
387+
* },
307388
* });
308389
* ```
309390
*
310391
* @example
311392
* Custom selection logic
312393
* ```typescript
313-
* const customMiddleware = bigToolMiddleware({
394+
* const customMiddleware = bigTool({
314395
* strategy: "custom",
315396
* customSelector: async ({ tools, query, context }) => {
316397
* // Your custom logic here
@@ -326,45 +407,74 @@ function selectToolsBySemantic(
326407
* @example
327408
* Runtime configuration override
328409
* ```typescript
410+
* // Override keyword strategy at runtime
329411
* await agent.invoke(
330412
* { messages: [new HumanMessage("Find files related to user data")] },
331413
* {
332414
* configurable: {
333415
* middleware_context: {
334-
* strategy: "keyword",
335-
* maxTools: 5,
336-
* keywordConfig: {
337-
* keywords: ["user", "data", "file"],
338-
* minMatches: 2
339-
* }
340-
* }
341-
* }
416+
* bigToolOptions: {
417+
* strategy: "keyword",
418+
* maxTools: 5,
419+
* keywordConfig: {
420+
* keywords: ["user", "data", "file"],
421+
* minMatches: 2,
422+
* },
423+
* },
424+
* },
425+
* },
426+
* }
427+
* );
428+
*
429+
* // Override semantic strategy (including embeddings) at runtime
430+
* const embed = (text: string) => embedText(text);
431+
* await agent.invoke(
432+
* { messages: [new HumanMessage("Query customer database records")] },
433+
* {
434+
* configurable: {
435+
* middleware_context: {
436+
* bigToolOptions: {
437+
* strategy: "semantic",
438+
* semanticConfig: {
439+
* threshold: 0.4,
440+
* maxTools: 8,
441+
* embedFunction: embed,
442+
* },
443+
* },
444+
* },
445+
* },
342446
* }
343447
* );
344448
* ```
345449
*/
346-
export function bigToolMiddleware(
450+
export function bigTool(
347451
middlewareOptions?: Partial<z.infer<typeof contextSchema>> & {
348452
customSelector?: CustomToolSelector;
349453
}
350454
) {
351455
return createMiddleware({
352456
name: "BigToolMiddleware",
353-
contextSchema,
457+
contextSchema: z.object({
458+
bigToolOptions: contextSchema,
459+
}),
354460

355461
prepareModelRequest: async (request, state, runtime) => {
462+
const contextConfiguration = runtime.context.bigToolOptions;
356463
// Get configuration with fallbacks
357464
const strategy =
358-
DEFAULT_STRATEGY === runtime.context?.strategy
465+
DEFAULT_STRATEGY === contextConfiguration?.strategy
359466
? middlewareOptions?.strategy ?? DEFAULT_STRATEGY
360467
: DEFAULT_STRATEGY;
361-
const maxTools = runtime.context?.maxTools ?? middlewareOptions?.maxTools;
468+
const maxTools =
469+
contextConfiguration?.maxTools ?? middlewareOptions?.maxTools;
362470
const keywordConfig =
363-
runtime.context?.keywordConfig ?? middlewareOptions?.keywordConfig;
471+
contextConfiguration?.keywordConfig ?? middlewareOptions?.keywordConfig;
364472
const semanticConfig =
365-
runtime.context?.semanticConfig ?? middlewareOptions?.semanticConfig;
473+
contextConfiguration?.semanticConfig ??
474+
middlewareOptions?.semanticConfig;
366475
const customSelector =
367-
runtime.context?.customSelector ?? middlewareOptions?.customSelector;
476+
contextConfiguration?.customSelector ??
477+
middlewareOptions?.customSelector;
368478

369479
const originalTools = request.tools;
370480
let selectedTools = originalTools;
@@ -390,7 +500,7 @@ export function bigToolMiddleware(
390500

391501
case "semantic":
392502
if (semanticConfig) {
393-
selectedTools = selectToolsBySemantic(
503+
selectedTools = await selectToolsBySemantic(
394504
originalTools,
395505
semanticConfig,
396506
query
@@ -401,7 +511,11 @@ export function bigToolMiddleware(
401511
case "custom":
402512
if (customSelector) {
403513
selectedTools = await Promise.resolve(
404-
customSelector(originalTools, query, runtime.context)
514+
customSelector({
515+
tools: originalTools,
516+
query,
517+
context: runtime.context,
518+
})
405519
);
406520
}
407521
break;

0 commit comments

Comments
 (0)