@@ -14,11 +14,26 @@ export type ToolSelectionStrategy =
1414/**
1515 * Custom tool selector function type
1616 */
17- export type CustomToolSelector < Context = Record < string , unknown > > = (
18- tools : ( ClientTool | ServerTool ) [ ] ,
19- query : string ,
20- context : Context
21- ) => Promise < ( ClientTool | ServerTool ) [ ] > | ( ClientTool | ServerTool ) [ ] ;
17+ const customToolSelector = z
18+ . function ( )
19+ . args (
20+ z . object ( {
21+ // all tools
22+ tools : z . array ( z . custom < ClientTool | ServerTool > ( ) ) . describe ( "Alltools" ) ,
23+ // user query
24+ query : z . string ( ) ,
25+ // agent context
26+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
27+ context : z . custom < Record < string , any > > ( ) ,
28+ } )
29+ )
30+ . returns (
31+ z . union ( [
32+ z . array ( z . custom < ClientTool | ServerTool > ( ) ) ,
33+ z . promise ( z . array ( z . custom < ClientTool | ServerTool > ( ) ) ) ,
34+ ] )
35+ ) ;
36+ export type CustomToolSelector = z . infer < typeof customToolSelector > ;
2237
2338/**
2439 * Keyword matching configuration
@@ -98,13 +113,18 @@ const contextSchema = z.object({
98113 . object ( {
99114 threshold : z . number ( ) . min ( 0 ) . max ( 1 ) . default ( 0.3 ) ,
100115 maxTools : z . number ( ) . positive ( ) . default ( 10 ) ,
116+ embedFunction : z
117+ . function ( )
118+ . args ( z . string ( ) )
119+ . returns ( z . union ( [ z . array ( z . number ( ) ) , z . promise ( z . array ( z . number ( ) ) ) ] ) )
120+ . optional ( ) ,
101121 } )
102122 . optional ( ) ,
103123
104124 /**
105125 * Custom tool selector function
106126 */
107- customSelector : z . custom < CustomToolSelector > ( ) . optional ( ) ,
127+ customSelector : customToolSelector . optional ( ) ,
108128} ) ;
109129
110130/**
@@ -208,20 +228,46 @@ function calculateSimpleSimilarity(text1: string, text2: string): number {
208228 return union . size > 0 ? intersection . size / union . size : 0 ;
209229}
210230
231+ /**
232+ * Calculate cosine similarity between two vectors
233+ */
234+ function cosine ( a : number [ ] , b : number [ ] ) : number {
235+ const num = a . reduce ( ( s , v , i ) => s + v * ( b [ i ] ?? 0 ) , 0 ) ;
236+ const da = Math . hypot ( ...a ) ;
237+ const db = Math . hypot ( ...b ) ;
238+ return da && db ? num / ( da * db ) : 0 ;
239+ }
240+
211241/**
212242 * Semantic similarity-based tool selection
213243 */
214- function selectToolsBySemantic (
244+ async function selectToolsBySemantic (
215245 tools : ( ClientTool | ServerTool ) [ ] ,
216246 config : SemanticMatchConfig ,
217247 query : string
218- ) : ( ClientTool | ServerTool ) [ ] {
219- const { threshold = 0.3 , maxTools = 10 } = config ;
248+ ) : Promise < ( ClientTool | ServerTool ) [ ] > {
249+ const { threshold = 0.3 , maxTools = 10 , embedFunction } = config ;
220250
221251 if ( ! query . trim ( ) ) {
222252 return tools . slice ( 0 , maxTools ) ;
223253 }
224254
255+ if ( embedFunction ) {
256+ const q = await Promise . resolve ( embedFunction ( query ) ) ;
257+ const scored = await Promise . all (
258+ tools . map ( async ( tool ) => {
259+ const t = getToolText ( tool ) ;
260+ const e = await Promise . resolve ( embedFunction ( t ) ) ;
261+ return { tool, score : cosine ( q , e ) } ;
262+ } )
263+ ) ;
264+ return scored
265+ . filter ( ( { score } ) => score >= threshold )
266+ . sort ( ( a , b ) => b . score - a . score )
267+ . slice ( 0 , maxTools )
268+ . map ( ( { tool } ) => tool ) ;
269+ }
270+
225271 const toolsWithScores = tools . map ( ( tool ) => ( {
226272 tool,
227273 score : calculateSimpleSimilarity ( query , getToolText ( tool ) ) ,
@@ -275,14 +321,14 @@ function selectToolsBySemantic(
275321 * @example
276322 * Basic usage with keyword matching
277323 * ```typescript
278- * import { bigToolMiddleware } from "langchain/middleware";
324+ * import { bigTool } from "langchain/middleware";
279325 * import { createAgent } from "langchain";
280326 *
281327 * const agent = createAgent({
282328 * model: "openai:gpt-4",
283329 * tools: [...manyTools], // 1000+ tools
284330 * middleware: [
285- * bigToolMiddleware ({
331+ * bigTool ({
286332 * strategy: "keyword",
287333 * maxTools: 20,
288334 * keywordConfig: {
@@ -298,19 +344,54 @@ function selectToolsBySemantic(
298344 * @example
299345 * Semantic similarity strategy
300346 * ```typescript
301- * const semanticMiddleware = bigToolMiddleware({
347+ * // Basic semantic similarity using simple text-overlap
348+ * const semanticMiddleware = bigTool({
302349 * strategy: "semantic",
303350 * semanticConfig: {
304351 * threshold: 0.4,
305- * maxTools: 15
306- * }
352+ * maxTools: 15,
353+ * },
354+ * });
355+ *
356+ * // Semantic similarity with a custom embedding function (e.g.cosine ranking)
357+ * // Note: embedFunction can be async. It should return a numeric vector.
358+ * import { embedText } from "./myEmbeddings";
359+ * const semanticWithEmbeddings = bigTool({
360+ * strategy: "semantic",
361+ * semanticConfig: {
362+ * threshold: 0.35,
363+ * maxTools: 10,
364+ * embedFunction: (text: string) => embedText(text),
365+ * },
366+ * });
367+ *
368+ * // Semantic similarity with a vendor embedding function
369+ * import OpenAI from "openai";
370+ * import { bigTool } from "langchain/middleware";
371+ *
372+ * const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY! });
373+ * const embedFunction = async (text: string): Promise<number[]> => {
374+ * const res = await openai.embeddings.create({
375+ * model: "text-embedding-3-small", // or "text-embedding-3-large"
376+ * input: text,
377+ * });
378+ * return res.data[0].embedding;
379+ * };
380+ *
381+ * const middleware = bigTool({
382+ * strategy: "semantic",
383+ * semanticConfig: {
384+ * threshold: 0.35,
385+ * maxTools: 10,
386+ * embedFunction,
387+ * },
307388 * });
308389 * ```
309390 *
310391 * @example
311392 * Custom selection logic
312393 * ```typescript
313- * const customMiddleware = bigToolMiddleware ({
394+ * const customMiddleware = bigTool ({
314395 * strategy: "custom",
315396 * customSelector: async (tools, query, context) => {
316397 * // Your custom logic here
@@ -326,45 +407,74 @@ function selectToolsBySemantic(
326407 * @example
327408 * Runtime configuration override
328409 * ```typescript
410+ * // Override keyword strategy at runtime
329411 * await agent.invoke(
330412 * { messages: [new HumanMessage("Find files related to user data")] },
331413 * {
332414 * configurable: {
333415 * middleware_context: {
334- * strategy: "keyword",
335- * maxTools: 5,
336- * keywordConfig: {
337- * keywords: ["user", "data", "file"],
338- * minMatches: 2
339- * }
340- * }
341- * }
416+ * bigToolOptions: {
417+ * strategy: "keyword",
418+ * maxTools: 5,
419+ * keywordConfig: {
420+ * keywords: ["user", "data", "file"],
421+ * minMatches: 2,
422+ * },
423+ * },
424+ * },
425+ * },
426+ * }
427+ * );
428+ *
429+ * // Override semantic strategy (including embeddings) at runtime
430+ * const embed = (text: string) => embedText(text);
431+ * await agent.invoke(
432+ * { messages: [new HumanMessage("Query customer database records")] },
433+ * {
434+ * configurable: {
435+ * middleware_context: {
436+ * bigToolOptions: {
437+ * strategy: "semantic",
438+ * semanticConfig: {
439+ * threshold: 0.4,
440+ * maxTools: 8,
441+ * embedFunction: embed,
442+ * },
443+ * },
444+ * },
445+ * },
342446 * }
343447 * );
344448 * ```
345449 */
346- export function bigToolMiddleware (
450+ export function bigTool (
347451 middlewareOptions ?: Partial < z . infer < typeof contextSchema > > & {
348452 customSelector ?: CustomToolSelector ;
349453 }
350454) {
351455 return createMiddleware ( {
352456 name : "BigToolMiddleware" ,
353- contextSchema,
457+ contextSchema : z . object ( {
458+ bigToolOptions : contextSchema ,
459+ } ) ,
354460
355461 prepareModelRequest : async ( request , state , runtime ) => {
462+ const contextConfiguration = runtime . context . bigToolOptions ;
356463 // Get configuration with fallbacks
357464 const strategy =
358- DEFAULT_STRATEGY === runtime . context ?. strategy
465+ DEFAULT_STRATEGY === contextConfiguration ?. strategy
359466 ? middlewareOptions ?. strategy ?? DEFAULT_STRATEGY
360467 : DEFAULT_STRATEGY ;
361- const maxTools = runtime . context ?. maxTools ?? middlewareOptions ?. maxTools ;
468+ const maxTools =
469+ contextConfiguration ?. maxTools ?? middlewareOptions ?. maxTools ;
362470 const keywordConfig =
363- runtime . context ?. keywordConfig ?? middlewareOptions ?. keywordConfig ;
471+ contextConfiguration ?. keywordConfig ?? middlewareOptions ?. keywordConfig ;
364472 const semanticConfig =
365- runtime . context ?. semanticConfig ?? middlewareOptions ?. semanticConfig ;
473+ contextConfiguration ?. semanticConfig ??
474+ middlewareOptions ?. semanticConfig ;
366475 const customSelector =
367- runtime . context ?. customSelector ?? middlewareOptions ?. customSelector ;
476+ contextConfiguration ?. customSelector ??
477+ middlewareOptions ?. customSelector ;
368478
369479 const originalTools = request . tools ;
370480 let selectedTools = originalTools ;
@@ -390,7 +500,7 @@ export function bigToolMiddleware(
390500
391501 case "semantic" :
392502 if ( semanticConfig ) {
393- selectedTools = selectToolsBySemantic (
503+ selectedTools = await selectToolsBySemantic (
394504 originalTools ,
395505 semanticConfig ,
396506 query
@@ -401,7 +511,11 @@ export function bigToolMiddleware(
401511 case "custom" :
402512 if ( customSelector ) {
403513 selectedTools = await Promise . resolve (
404- customSelector ( originalTools , query , runtime . context )
514+ customSelector ( {
515+ tools : originalTools ,
516+ query,
517+ context : runtime . context ,
518+ } )
405519 ) ;
406520 }
407521 break ;
0 commit comments