Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,6 @@ credentials.json
bun.lock
coverage/

examples/src/createAgent/*.png
examples/src/createAgent/*.png
# LangGraph API
.langgraph_api
1 change: 1 addition & 0 deletions libs/langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,4 @@ schema/prompt_template.d.cts
node_modules
dist
.yarn
graph-matrix.mermaid.md
214 changes: 165 additions & 49 deletions libs/langchain/src/agents/middlewareAgent/ReactAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ import { AgentNode } from "./nodes/AgentNode.js";
import { ToolNode } from "../nodes/ToolNode.js";
import { BeforeModelNode } from "./nodes/BeforeModalNode.js";
import { AfterModelNode } from "./nodes/AfterModalNode.js";
import { initializeMiddlewareStates } from "./nodes/utils.js";
import {
initializeMiddlewareStates,
parseJumpToTarget,
} from "./nodes/utils.js";

import type { ClientTool, ServerTool, WithStateGraphNodes } from "../types.js";

import {
import type {
CreateAgentParams,
AgentMiddleware,
InferMiddlewareStates,
Expand All @@ -35,6 +38,7 @@ import {
InferContextInput,
InvokeConfiguration,
StreamConfiguration,
JumpToDestination,
} from "./types.js";

import {
Expand Down Expand Up @@ -131,8 +135,16 @@ export class ReactAgent<
>;

// Generate node names for middleware nodes that have hooks
const beforeModelNodes: { index: number; name: string }[] = [];
const afterModelNodes: { index: number; name: string }[] = [];
const beforeModelNodes: {
index: number;
name: string;
allowed?: string[];
}[] = [];
const afterModelNodes: {
index: number;
name: string;
allowed?: string[];
}[] = [];
const modifyModelRequestHookMiddleware: [
AgentMiddleware,
/**
Expand All @@ -141,17 +153,24 @@ export class ReactAgent<
() => any
][] = [];

const middlewareNames = new Set<string>();
const middleware = this.options.middleware ?? [];
for (let i = 0; i < middleware.length; i++) {
let beforeModelNode: BeforeModelNode | undefined;
let afterModelNode: AfterModelNode | undefined;
const m = middleware[i];
if (middlewareNames.has(m.name)) {
throw new Error(`Middleware ${m.name} is defined multiple times`);
}

middlewareNames.add(m.name);
if (m.beforeModel) {
beforeModelNode = new BeforeModelNode(m);
const name = `before_model_${m.name}_${i}`;
const name = `${m.name}.before_model`;
beforeModelNodes.push({
index: i,
name,
allowed: m.beforeModelJumpTo,
});
allNodeWorkflows.addNode(
name,
Expand All @@ -161,10 +180,11 @@ export class ReactAgent<
}
if (m.afterModel) {
afterModelNode = new AfterModelNode(m);
const name = `after_model_${m.name}_${i}`;
const name = `${m.name}.after_model`;
afterModelNodes.push({
index: i,
name,
allowed: m.afterModelJumpTo,
});
allNodeWorkflows.addNode(
name,
Expand Down Expand Up @@ -226,18 +246,35 @@ export class ReactAgent<
allNodeWorkflows.addEdge(START, "model_request");
}

// Connect beforeModel nodes in sequence
for (let i = 0; i < beforeModelNodes.length - 1; i++) {
allNodeWorkflows.addEdge(
beforeModelNodes[i].name,
beforeModelNodes[i + 1].name
);
}
// Connect beforeModel nodes; add conditional routing ONLY if allowed jumps are specified
for (let i = 0; i < beforeModelNodes.length; i++) {
const node = beforeModelNodes[i];
const current = node.name;
const isLast = i === beforeModelNodes.length - 1;
const nextDefault = isLast
? "model_request"
: beforeModelNodes[i + 1].name;

if (node.allowed && node.allowed.length > 0) {
const hasTools = toolClasses.filter(isClientTool).length > 0;
const allowedMapped = node.allowed
.map((t) => parseJumpToTarget(t))
.filter((dest) => dest !== "tools" || hasTools);
const destinations = Array.from(
new Set([nextDefault, ...allowedMapped])
) as ("tools" | "model_request" | typeof END)[];

// Connect last beforeModel node to agent
const lastBeforeModelNode = beforeModelNodes.at(-1);
if (beforeModelNodes.length > 0 && lastBeforeModelNode) {
allNodeWorkflows.addEdge(lastBeforeModelNode.name, "model_request");
allNodeWorkflows.addConditionalEdges(
current,
this.#createBeforeModelRouter(
toolClasses.filter(isClientTool),
nextDefault
),
destinations
);
} else {
allNodeWorkflows.addEdge(current, nextDefault);
}
}

// Connect agent to last afterModel node (for reverse order execution)
Expand All @@ -258,27 +295,59 @@ export class ReactAgent<
}
}

// Connect afterModel nodes in reverse sequence
// Connect afterModel nodes in reverse sequence; add conditional routing ONLY if allowed jumps are specified per node
for (let i = afterModelNodes.length - 1; i > 0; i--) {
allNodeWorkflows.addEdge(
afterModelNodes[i].name,
afterModelNodes[i - 1].name
);
const node = afterModelNodes[i];
const current = node.name;
const nextDefault = afterModelNodes[i - 1].name;

if (node.allowed && node.allowed.length > 0) {
const hasTools = toolClasses.filter(isClientTool).length > 0;
const allowedMapped = node.allowed
.map((t) => parseJumpToTarget(t))
.filter((dest) => dest !== "tools" || hasTools);
const destinations = Array.from(
new Set([nextDefault, ...allowedMapped])
) as ("tools" | "model_request" | typeof END)[];

allNodeWorkflows.addConditionalEdges(
current,
this.#createAfterModelSequenceRouter(
toolClasses.filter(isClientTool),
node.allowed,
nextDefault
),
destinations
);
} else {
allNodeWorkflows.addEdge(current, nextDefault);
}
}

// Connect first afterModel node (last to execute) to model paths with jumpTo support
if (afterModelNodes.length > 0) {
const firstAfterModelNode = afterModelNodes[0].name;
const firstAfterModel = afterModelNodes[0];
const firstAfterModelNode = firstAfterModel.name;
const modelPaths = this.#getModelPaths(
toolClasses.filter(isClientTool),
true
).filter(
(p) => p !== "tools" || toolClasses.filter(isClientTool).length > 0
);

const allowJump = Boolean(
firstAfterModel.allowed && firstAfterModel.allowed.length > 0
);

// Use afterModel router which includes jumpTo logic
const destinations = modelPaths;

allNodeWorkflows.addConditionalEdges(
firstAfterModelNode,
this.#createAfterModelRouter(toolClasses.filter(isClientTool)),
modelPaths
this.#createAfterModelRouter(
toolClasses.filter(isClientTool),
allowJump
),
destinations
);
}

Expand Down Expand Up @@ -444,10 +513,15 @@ export class ReactAgent<
* @param toolClasses - Available tool classes for validation
* @returns Router function that handles jumpTo logic and normal routing
*/
#createAfterModelRouter(toolClasses: (ClientTool | ServerTool)[]) {
#createAfterModelRouter(
toolClasses: (ClientTool | ServerTool)[],
allowJump: boolean
) {
const hasStructuredResponse = Boolean(this.options.responseFormat);

return (state: BuiltInState) => {
return (
state: Omit<BuiltInState, "jumpTo"> & { jumpTo?: JumpToDestination }
) => {
// First, check if we just processed a structured response
// If so, ignore any existing jumpTo and go to END
const messages = state.messages;
Expand All @@ -459,33 +533,20 @@ export class ReactAgent<
return END;
}

// Check if jumpTo is set in the state
if (state.jumpTo) {
const jumpTarget = state.jumpTo;

// If jumpTo is "model", go to model_request node
if (jumpTarget === "model") {
return "model_request";
// Check if jumpTo is set in the state and allowed
if (allowJump && state.jumpTo) {
if (state.jumpTo === END) {
return END;
}

// If jumpTo is "tools", go to tools node
if (jumpTarget === "tools") {
if (state.jumpTo === "tools") {
// If trying to jump to tools but no tools are available, go to END
if (toolClasses.length === 0) {
return END;
}

return "tools";
}

// If jumpTo is END, go to END
if (jumpTarget === END) {
return END;
return new Send("tools", { ...state, jumpTo: undefined });
}

throw new Error(
`Invalid jump target: ${jumpTarget}, must be "model" or "tools".`
);
// destination === "model_request"
return new Send("model_request", { ...state, jumpTo: undefined });
}

// check if there are pending tool calls
Expand Down Expand Up @@ -545,6 +606,61 @@ export class ReactAgent<
};
}

/**
* Router for afterModel sequence nodes (connecting later middlewares to earlier ones),
* honoring allowed jump targets and defaulting to the next node.
*/
#createAfterModelSequenceRouter(
toolClasses: (ClientTool | ServerTool)[],
allowed: string[],
nextDefault: string
) {
const allowedSet = new Set(allowed.map((t) => parseJumpToTarget(t)));
return (state: BuiltInState) => {
if (state.jumpTo) {
const dest = parseJumpToTarget(state.jumpTo);
if (dest === END && allowedSet.has(END)) {
return END;
}
if (dest === "tools" && allowedSet.has("tools")) {
if (toolClasses.length === 0) return END;
return new Send("tools", { ...state, jumpTo: undefined });
}
if (dest === "model_request" && allowedSet.has("model_request")) {
return new Send("model_request", { ...state, jumpTo: undefined });
}
}
return nextDefault as any;
};
}

/**
* Create routing function for jumpTo functionality after beforeModel hooks.
* Falls back to the default next node if no jumpTo is present.
*/
#createBeforeModelRouter(
toolClasses: (ClientTool | ServerTool)[],
nextDefault: string
) {
return (state: BuiltInState) => {
if (!state.jumpTo) {
return nextDefault;
}
const destination = parseJumpToTarget(state.jumpTo);
if (destination === END) {
return END;
}
if (destination === "tools") {
if (toolClasses.length === 0) {
return END;
}
return new Send("tools", { ...state, jumpTo: undefined });
}
// destination === "model_request"
return new Send("model_request", { ...state, jumpTo: undefined });
};
}

/**
* Initialize middleware states if not already present in the input state.
*/
Expand Down
1 change: 1 addition & 0 deletions libs/langchain/src/agents/middlewareAgent/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export const JUMP_TO_TARGETS = ["model", "tools", "end"] as const;
Loading
Loading