Skip to content

Commit 4066305

Browse files
fix(langchain/createAgent): tools in ModelRequest back to tool instance (#9119)
1 parent 2cb50a6 commit 4066305

File tree

14 files changed

+122
-185
lines changed

14 files changed

+122
-185
lines changed

examples/src/createAgent/dynamicTools/advanced.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ const selectToolsMiddleware = createMiddleware({
6969
name: "SelectToolsMiddleware",
7070
modifyModelRequest: async (request, state) => {
7171
const last = state.messages.at(-1);
72-
const active = last?.content
72+
const tools = last?.content
7373
? // only give me the most relevant tool
7474
await selectTopKBySimilarity(last.content as string, 1)
7575
: fullCatalog.slice(0, 5);
76-
return { ...request, tools: active.map((t) => t.name) };
76+
return { ...request, tools };
7777
},
7878
});
7979

examples/src/createAgent/dynamicTools/simple.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ const vcsToolGate = createMiddleware({
3535
contextSchema: z.object({ vcsProvider: z.string() }),
3636
modifyModelRequest: (request, _state, runtime) => {
3737
const provider = runtime.context.vcsProvider.toLowerCase();
38-
const active =
38+
const tools =
3939
provider === "gitlab" ? [gitlabCreateIssue] : [githubCreateIssue];
40-
return { ...request, tools: active.map((t) => t.name) };
40+
return { ...request, tools };
4141
},
4242
});
4343

libs/langchain/src/agents/middlewareAgent/middleware.ts

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ export function createMiddleware<
107107
: {}) &
108108
AgentBuiltInState,
109109
runtime: Runtime<
110-
(TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}) &
111-
AgentBuiltInState,
112110
TContextSchema extends InteropZodObject
113111
? InferInteropZodOutput<TContextSchema>
114112
: TContextSchema extends InteropZodDefault<any>
@@ -132,8 +130,6 @@ export function createMiddleware<
132130
: {}) &
133131
AgentBuiltInState,
134132
runtime: Runtime<
135-
(TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}) &
136-
AgentBuiltInState,
137133
TContextSchema extends InteropZodObject
138134
? InferInteropZodOutput<TContextSchema>
139135
: TContextSchema extends InteropZodDefault<any>
@@ -171,8 +167,6 @@ export function createMiddleware<
171167
: {}) &
172168
AgentBuiltInState,
173169
runtime: Runtime<
174-
(TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}) &
175-
AgentBuiltInState,
176170
TContextSchema extends InteropZodObject
177171
? InferInteropZodOutput<TContextSchema>
178172
: TContextSchema extends InteropZodDefault<any>
@@ -213,10 +207,6 @@ export function createMiddleware<
213207
options,
214208
state,
215209
runtime as Runtime<
216-
(TSchema extends InteropZodObject
217-
? InferInteropZodInput<TSchema>
218-
: {}) &
219-
AgentBuiltInState,
220210
TContextSchema extends InteropZodObject
221211
? InferInteropZodOutput<TContextSchema>
222212
: TContextSchema extends InteropZodDefault<any>
@@ -235,10 +225,6 @@ export function createMiddleware<
235225
config.beforeModel!(
236226
state,
237227
runtime as Runtime<
238-
(TSchema extends InteropZodObject
239-
? InferInteropZodInput<TSchema>
240-
: {}) &
241-
AgentBuiltInState,
242228
TContextSchema extends InteropZodObject
243229
? InferInteropZodOutput<TContextSchema>
244230
: TContextSchema extends InteropZodDefault<any>
@@ -257,10 +243,6 @@ export function createMiddleware<
257243
config.afterModel!(
258244
state,
259245
runtime as Runtime<
260-
(TSchema extends InteropZodObject
261-
? InferInteropZodInput<TSchema>
262-
: {}) &
263-
AgentBuiltInState,
264246
TContextSchema extends InteropZodObject
265247
? InferInteropZodOutput<TContextSchema>
266248
: TContextSchema extends InteropZodDefault<any>

libs/langchain/src/agents/middlewareAgent/middleware/dynamicSystemPrompt.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { Runtime, AgentBuiltInState } from "../types.js";
33

44
export type DynamicSystemPromptMiddlewareConfig<TContextSchema> = (
55
state: AgentBuiltInState,
6-
runtime: Runtime<AgentBuiltInState, TContextSchema>
6+
runtime: Runtime<TContextSchema>
77
) => string | Promise<string>;
88

99
/**
@@ -52,7 +52,7 @@ export function dynamicSystemPromptMiddleware<TContextSchema = unknown>(
5252
modifyModelRequest: async (options, state, runtime) => {
5353
const systemPrompt = await fn(
5454
state as AgentBuiltInState,
55-
runtime as Runtime<AgentBuiltInState, TContextSchema>
55+
runtime as Runtime<TContextSchema>
5656
);
5757

5858
if (typeof systemPrompt !== "string") {

libs/langchain/src/agents/middlewareAgent/middleware/hitl.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ export type DescriptionFactory<
6767
> = (
6868
toolCall: ToolCall,
6969
state: State,
70-
runtime: Runtime<State>
70+
runtime: Runtime<unknown>
7171
) => string | Promise<string>;
7272

7373
/**

libs/langchain/src/agents/middlewareAgent/middleware/llmToolSelector.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ export function llmToolSelectorMiddleware(
110110
/**
111111
* Extract tool information
112112
*/
113-
const toolInfo = runtime.tools.map((tool) => ({
113+
const toolInfo = request.tools.map((tool) => ({
114114
name: tool.name as string,
115115
description: tool.description,
116116
tool,
@@ -225,12 +225,19 @@ export function llmToolSelectorMiddleware(
225225
}
226226
}
227227

228+
/**
229+
* If no tools were selected after all retries, fall back to all tools
230+
*/
231+
if (selectedToolNames.length === 0) {
232+
return request;
233+
}
234+
228235
/**
229236
* Filter tools based on selection
230237
*/
231-
const selectedTools = toolInfo
232-
.filter(({ name }) => selectedToolNames.includes(name))
233-
.map(({ name }) => name);
238+
const selectedTools = toolInfo.filter(({ name }) =>
239+
selectedToolNames.includes(name)
240+
);
234241

235242
return {
236243
...request,

libs/langchain/src/agents/middlewareAgent/nodes/AfterModalNode.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ export class AfterModelNode<
2828
this.name = `AfterModelNode_${middleware.name}`;
2929
}
3030

31-
runHook(state: TStateSchema, runtime: Runtime<TStateSchema, TContextSchema>) {
31+
runHook(state: TStateSchema, runtime: Runtime<TContextSchema>) {
3232
return this.middleware.afterModel!(
3333
state as Record<string, any> & AgentBuiltInState,
34-
runtime as Runtime<TStateSchema, unknown>
34+
runtime as Runtime<unknown>
3535
) as Promise<MiddlewareResult<TStateSchema>>;
3636
}
3737
}

libs/langchain/src/agents/middlewareAgent/nodes/AgentNode.ts

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
getPromptRunnable,
2323
validateLLMHasNoBoundTools,
2424
hasToolCalls,
25+
isClientTool,
2526
} from "../../utils.js";
2627
import { mergeAbortSignals } from "../../nodes/utils.js";
2728
import {
@@ -40,7 +41,6 @@ import {
4041
ToolStrategyError,
4142
hasSupportForJsonSchemaOutput,
4243
} from "../../responses.js";
43-
import { parseToolCalls } from "./utils.js";
4444

4545
type ResponseHandlerResult<StructuredResponseFormat> =
4646
| {
@@ -570,7 +570,7 @@ export class AgentNode<
570570
model,
571571
systemPrompt,
572572
messages: state.messages,
573-
tools: this.#options.toolClasses.map((tool) => tool.name as string),
573+
tools: this.#options.toolClasses,
574574
};
575575

576576
/**
@@ -588,14 +588,11 @@ export class AgentNode<
588588
/**
589589
* Create runtime
590590
*/
591-
const runtime: Runtime<unknown, unknown> = {
592-
toolCalls: parseToolCalls(state.messages),
593-
tools: this.#options.toolClasses,
591+
const runtime: Runtime<unknown> = {
594592
context,
595593
writer: config.writer,
596594
interrupt: config.interrupt,
597595
signal: config.signal,
598-
terminate: (result) => ({ type: "terminate", result }),
599596
};
600597

601598
const result = await middleware.modifyModelRequest!(
@@ -613,26 +610,47 @@ export class AgentNode<
613610
})
614611
);
615612

616-
/**
617-
* raise meaningful error if unknown tools were selected
618-
*/
619-
const unknownTools =
620-
result?.tools?.filter(
621-
(tool) => !this.#options.toolClasses.some((t) => t.name === tool)
622-
) ?? [];
623-
if (unknownTools.length > 0) {
624-
throw new Error(
625-
`Unknown tools selected in middleware "${
626-
middleware.name
627-
}": ${unknownTools.join(
628-
", "
629-
)}, available tools: ${this.#options.toolClasses
630-
.map((t) => t.name)
631-
.join(", ")}!`
613+
if (result) {
614+
const modifiedTools = result.tools ?? [];
615+
616+
/**
617+
* Verify that the user didn't add any new tools.
618+
* We can't allow this as the ToolNode is already initiated with given tools.
619+
*/
620+
const newTools = modifiedTools.filter(
621+
(tool) =>
622+
isClientTool(tool) &&
623+
!this.#options.toolClasses.some((t) => t.name === tool.name)
632624
);
633-
}
625+
if (newTools.length > 0) {
626+
throw new Error(
627+
`You have added a new tool in "modifyModelRequest" hook of middleware "${
628+
middleware.name
629+
}": ${newTools
630+
.map((tool) => tool.name)
631+
.join(", ")}. This is not supported.`
632+
);
633+
}
634+
635+
/**
636+
* Verify that user has not added or modified a tool with the same name.
637+
* We can't allow this as the ToolNode is already initiated with given tools.
638+
*/
639+
const invalidTools = modifiedTools.filter(
640+
(tool) =>
641+
isClientTool(tool) &&
642+
this.#options.toolClasses.every((t) => t !== tool)
643+
);
644+
if (invalidTools.length > 0) {
645+
throw new Error(
646+
`You have modified a tool in "modifyModelRequest" hook of middleware "${
647+
middleware.name
648+
}": ${invalidTools
649+
.map((tool) => tool.name)
650+
.join(", ")}. This is not supported.`
651+
);
652+
}
634653

635-
if (result) {
636654
currentOptions = { ...currentOptions, ...result };
637655
}
638656
}
@@ -653,14 +671,8 @@ export class AgentNode<
653671
);
654672

655673
// Use tools from preparedOptions if provided, otherwise use default tools
656-
const preparedTools = preparedOptions?.tools ?? [];
657674
const allTools = [
658-
...(preparedTools.length > 0
659-
? this.#options.toolClasses.filter(
660-
(tool) =>
661-
typeof tool.name === "string" && preparedTools.includes(tool.name)
662-
)
663-
: this.#options.toolClasses),
675+
...(preparedOptions?.tools ?? []),
664676
...structuredTools.map((toolStrategy) => toolStrategy.tool),
665677
];
666678

libs/langchain/src/agents/middlewareAgent/nodes/BeforeModalNode.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ export class BeforeModelNode<
2727
});
2828
}
2929

30-
runHook(state: TStateSchema, runtime: Runtime<TStateSchema, TContextSchema>) {
30+
runHook(state: TStateSchema, runtime: Runtime<TContextSchema>) {
3131
return this.middleware.beforeModel!(
3232
state as Record<string, any> & AgentBuiltInState,
33-
runtime as Runtime<TStateSchema, unknown>
33+
runtime as Runtime<unknown>
3434
) as Promise<MiddlewareResult<TStateSchema>>;
3535
}
3636
}

libs/langchain/src/agents/middlewareAgent/nodes/middleware.ts

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
2-
/* eslint-disable no-instanceof/no-instanceof */
32
import { z } from "zod/v3";
43
import { LangGraphRunnableConfig, Command } from "@langchain/langgraph";
54
import { interopParse } from "@langchain/core/utils/types";
65

76
import { RunnableCallable } from "../../RunnableCallable.js";
87
import type {
98
Runtime,
10-
ControlAction,
119
AgentMiddleware,
1210
MiddlewareResult,
1311
JumpToTarget,
1412
} from "../types.js";
15-
import {
16-
derivePrivateState,
17-
parseToolCalls,
18-
parseJumpToTarget,
19-
} from "./utils.js";
13+
import { derivePrivateState, parseJumpToTarget } from "./utils.js";
2014

2115
type NodeOutput<TStateSchema extends Record<string, any>> =
2216
| TStateSchema
@@ -33,7 +27,7 @@ export abstract class MiddlewareNode<
3327

3428
abstract runHook(
3529
state: TStateSchema,
36-
config?: Runtime<TStateSchema, TContextSchema>
30+
config?: Runtime<TContextSchema>
3731
): Promise<MiddlewareResult<TStateSchema>>;
3832

3933
async invokeMiddleware(
@@ -74,21 +68,11 @@ export abstract class MiddlewareNode<
7468
/**
7569
* ToDo: implement later
7670
*/
77-
const runtime: Runtime<TStateSchema, TContextSchema> = {
78-
toolCalls: parseToolCalls(state.messages),
71+
const runtime: Runtime<TContextSchema> = {
7972
context: filteredContext,
8073
writer: config?.writer,
8174
interrupt: config?.interrupt,
8275
signal: config?.signal,
83-
tools: this.middleware.tools ?? [],
84-
terminate: (
85-
result?: Partial<TStateSchema> | Error
86-
): ControlAction<TStateSchema> => {
87-
if (result instanceof Error) {
88-
throw result;
89-
}
90-
return { type: "terminate", result };
91-
},
9276
};
9377

9478
const result = await this.runHook(state, runtime);

0 commit comments

Comments
 (0)