diff --git a/libs/langchain/src/agents/middlewareAgent/ReactAgent.ts b/libs/langchain/src/agents/middlewareAgent/ReactAgent.ts index 2192e658b620..59512fe5bf8d 100644 --- a/libs/langchain/src/agents/middlewareAgent/ReactAgent.ts +++ b/libs/langchain/src/agents/middlewareAgent/ReactAgent.ts @@ -173,6 +173,13 @@ export class ReactAgent< */ () => any ][] = []; + const retryModelRequestHookMiddleware: [ + AgentMiddleware, + /** + * ToDo: better type to get the state of middleware + */ + () => any + ][] = []; this.#agentNode = new AgentNode({ model: this.options.model, @@ -185,6 +192,7 @@ export class ReactAgent< shouldReturnDirect, signal: this.options.signal, modifyModelRequestHookMiddleware, + retryModelRequestHookMiddleware, }); const middlewareNames = new Set(); @@ -240,6 +248,16 @@ export class ReactAgent< }), ]); } + + if (m.retryModelRequest) { + retryModelRequestHookMiddleware.push([ + m, + () => ({ + ...beforeModelNode?.getState(), + ...afterModelNode?.getState(), + }), + ]); + } } /** diff --git a/libs/langchain/src/agents/middlewareAgent/middleware.ts b/libs/langchain/src/agents/middlewareAgent/middleware.ts index 474a1912be57..bda5a1b6ad83 100644 --- a/libs/langchain/src/agents/middlewareAgent/middleware.ts +++ b/libs/langchain/src/agents/middlewareAgent/middleware.ts @@ -116,6 +116,34 @@ export function createMiddleware< : never > ) => Promise | ModelRequest | void; + /** + * The function to handle model invocation errors and optionally retry. + * + * @param error - The exception that occurred during model invocation + * @param request - The original model request that failed + * @param state - The current agent state + * @param runtime - The runtime context + * @param attempt - The current attempt number (1-indexed) + * @returns Modified request to retry with, or undefined/null to propagate the error (re-raise) + */ + retryModelRequest?: ( + error: Error, + request: ModelRequest, + state: (TSchema extends InteropZodObject + ? 
InferInteropZodInput + : {}) & + AgentBuiltInState, + runtime: Runtime< + TContextSchema extends InteropZodObject + ? InferInteropZodOutput + : TContextSchema extends InteropZodDefault + ? InferInteropZodOutput + : TContextSchema extends InteropZodOptional + ? Partial> + : never + >, + attempt: number + ) => Promise | ModelRequest | void; /** * The function to run before the model call. This function is called before the model is invoked and before the `modifyModelRequest` hook. * It allows to modify the state of the agent. @@ -219,6 +247,33 @@ export function createMiddleware< ); } + if (config.retryModelRequest) { + middleware.retryModelRequest = async ( + error, + request, + state, + runtime, + attempt + ) => + Promise.resolve( + config.retryModelRequest!( + error, + request, + state, + runtime as Runtime< + TContextSchema extends InteropZodObject + ? InferInteropZodOutput + : TContextSchema extends InteropZodDefault + ? InferInteropZodOutput + : TContextSchema extends InteropZodOptional + ? 
Partial> + : never + >, + attempt + ) + ); + } + if (config.beforeModel) { middleware.beforeModel = async (state, runtime) => Promise.resolve( diff --git a/libs/langchain/src/agents/middlewareAgent/middleware/index.ts b/libs/langchain/src/agents/middlewareAgent/middleware/index.ts index ddba86773359..64943a46f17b 100644 --- a/libs/langchain/src/agents/middlewareAgent/middleware/index.ts +++ b/libs/langchain/src/agents/middlewareAgent/middleware/index.ts @@ -30,4 +30,5 @@ export { modelCallLimitMiddleware, type ModelCallLimitMiddlewareConfig, } from "./callLimit.js"; +export { modelFallbackMiddleware } from "./modelFallback.js"; export { type AgentMiddleware } from "../types.js"; diff --git a/libs/langchain/src/agents/middlewareAgent/middleware/modelFallback.ts b/libs/langchain/src/agents/middlewareAgent/middleware/modelFallback.ts new file mode 100644 index 000000000000..c9d15b92bbbb --- /dev/null +++ b/libs/langchain/src/agents/middlewareAgent/middleware/modelFallback.ts @@ -0,0 +1,83 @@ +import type { LanguageModelLike } from "@langchain/core/language_models/base"; +import { initChatModel } from "../../../chat_models/universal.js"; +import type { ModelRequest, AgentMiddleware } from "../types.js"; +import { createMiddleware } from "../middleware.js"; + +/** + * Middleware that provides automatic model fallback on errors. + * + * This middleware attempts to retry failed model calls with alternative models + * in sequence. When a model call fails, it tries the next model in the fallback + * list until either a call succeeds or all models have been exhausted. 
+ * + * @example + * ```ts + * import { createAgent, modelFallbackMiddleware } from "langchain"; + * + * // Create middleware with fallback models (not including primary) + * const fallback = modelFallbackMiddleware( + * "openai:gpt-4o-mini", // First fallback + * "anthropic:claude-3-5-sonnet-20241022", // Second fallback + * ); + * + * const agent = createAgent({ + * model: "openai:gpt-4o", // Primary model + * middleware: [fallback], + * tools: [], + * }); + * + * // If gpt-4o fails, automatically tries gpt-4o-mini, then claude + * const result = await agent.invoke({ + * messages: [{ role: "user", content: "Hello" }] + * }); + * ``` + * + * @param fallbackModels - The fallback models to try, in order. + * @returns A middleware instance that handles model failures with fallbacks + */ +export function modelFallbackMiddleware( + /** + * The fallback models to try, in order. + */ + ...fallbackModels: (string | LanguageModelLike)[] +): AgentMiddleware { + return createMiddleware({ + name: "modelFallbackMiddleware", + retryModelRequest: async ( + _error, + request, + _state, + _runtime, + attempt + ): Promise => { + /** + * attempt 1 = primary model failed, try models[0] (first fallback) + */ + const fallbackIndex = attempt - 1; + + /** + * All fallback models exhausted + */ + if (fallbackIndex >= fallbackModels.length) { + return undefined; + } + + /** + * Get or initialize the fallback model + */ + const fallbackModel = fallbackModels[fallbackIndex]; + const model = + typeof fallbackModel === "string" + ? 
await initChatModel(fallbackModel) + : fallbackModel; + + /** + * Try next fallback model + */ + return { + ...request, + model, + }; + }, + }); +} diff --git a/libs/langchain/src/agents/middlewareAgent/middleware/tests/modelFallback.test.ts b/libs/langchain/src/agents/middlewareAgent/middleware/tests/modelFallback.test.ts new file mode 100644 index 000000000000..daffaf251f24 --- /dev/null +++ b/libs/langchain/src/agents/middlewareAgent/middleware/tests/modelFallback.test.ts @@ -0,0 +1,97 @@ +import { expect, describe, it, vi } from "vitest"; +import { HumanMessage, AIMessage } from "@langchain/core/messages"; +import { LanguageModelLike } from "@langchain/core/language_models/base"; + +import { createAgent } from "../../index.js"; +import { modelFallbackMiddleware } from "../modelFallback.js"; + +function createMockModel(name = "ChatAnthropic", model = "anthropic") { + // Mock Anthropic model + const invokeCallback = vi + .fn() + .mockResolvedValue(new AIMessage("Response from model")); + return { + getName: () => name, + bindTools: vi.fn().mockReturnThis(), + _streamResponseChunks: vi.fn().mockReturnThis(), + bind: vi.fn().mockReturnThis(), + invoke: invokeCallback, + lc_runnable: true, + _modelType: model, + _generate: vi.fn(), + _llmType: () => model, + } as unknown as LanguageModelLike; +} + +describe("modelFallbackMiddleware", () => { + it("should retry the model request with the new model", async () => { + const model = createMockModel(); + model.invoke = vi.fn().mockRejectedValue(new Error("Model error")); + const retryModel = createMockModel("ChatAnthropic", "anthropic"); + const agent = createAgent({ + model, + tools: [], + middleware: [modelFallbackMiddleware(retryModel)] as const, + }); + await agent.invoke({ messages: [new HumanMessage("Hello, world!")] }); + expect(model.invoke).toHaveBeenCalledTimes(1); + expect(retryModel.invoke).toHaveBeenCalledTimes(1); + }); + + it("should allow to configure additional models", async () => { + const model = 
createMockModel(); + model.invoke = vi + .fn() + .mockRejectedValueOnce(new Error("Model error")) + .mockResolvedValueOnce(new AIMessage("Response from model")); + const anotherFailingModel = createMockModel(); + anotherFailingModel.invoke = vi + .fn() + .mockRejectedValue(new Error("Model error")); + const retryModel = createMockModel("ChatAnthropic", "anthropic"); + const agent = createAgent({ + model, + tools: [], + middleware: [ + modelFallbackMiddleware( + anotherFailingModel, + anotherFailingModel, + anotherFailingModel, + retryModel + ), + ] as const, + }); + + await agent.invoke({ messages: [new HumanMessage("Hello, world!")] }); + expect(model.invoke).toHaveBeenCalledTimes(1); + expect(anotherFailingModel.invoke).toHaveBeenCalledTimes(3); + expect(retryModel.invoke).toHaveBeenCalledTimes(1); + }); + + it("should throw if list is exhausted", async () => { + const model = createMockModel(); + model.invoke = vi + .fn() + .mockRejectedValueOnce(new Error("Model error")) + .mockResolvedValueOnce(new AIMessage("Response from model")); + const anotherFailingModel = createMockModel(); + anotherFailingModel.invoke = vi + .fn() + .mockRejectedValue(new Error("Model error")); + const agent = createAgent({ + model, + tools: [], + middleware: [ + modelFallbackMiddleware( + anotherFailingModel, + anotherFailingModel, + anotherFailingModel + ), + ] as const, + }); + + await expect( + agent.invoke({ messages: [new HumanMessage("Hello, world!")] }) + ).rejects.toThrow("Model error"); + }); +}); diff --git a/libs/langchain/src/agents/middlewareAgent/nodes/AgentNode.ts b/libs/langchain/src/agents/middlewareAgent/nodes/AgentNode.ts index 2effc9d4b971..6e75afd37dfb 100644 --- a/libs/langchain/src/agents/middlewareAgent/nodes/AgentNode.ts +++ b/libs/langchain/src/agents/middlewareAgent/nodes/AgentNode.ts @@ -72,6 +72,10 @@ export interface AgentNodeOptions< AgentMiddleware, () => any ][]; + retryModelRequestHookMiddleware?: [ + AgentMiddleware, + () => any + ][]; } interface 
NativeResponseFormat { @@ -273,94 +277,174 @@ export class AgentNode< /** * Execute modifyModelRequest hooks from beforeModelNodes */ - const preparedOptions = await this.#executePrepareModelRequestHooks( + let preparedOptions = await this.#executePrepareModelRequestHooks( model, state, config ); /** - * If user provides a model in the preparedOptions, use it, - * otherwise use the model from the options + * Retry loop for model invocation with error handling + * Hard limit of 100 attempts to prevent infinite loops from buggy middleware */ - const finalModel = preparedOptions?.model ?? model; + const maxAttempts = 100; + for (let attempt = 1; attempt <= maxAttempts; attempt++) { + try { + /** + * If user provides a model in the preparedOptions, use it, + * otherwise use the model from the options + */ + const finalModel = preparedOptions?.model ?? model; - /** - * Check if the LLM already has bound tools and throw if it does. - */ - validateLLMHasNoBoundTools(finalModel); + /** + * Check if the LLM already has bound tools and throw if it does. 
+ */ + validateLLMHasNoBoundTools(finalModel); - const structuredResponseFormat = this.#getResponseFormat(finalModel); - const modelWithTools = await this.#bindTools( - finalModel, - preparedOptions, - structuredResponseFormat - ); - let modelInput = this.#getModelInputState(state); + const structuredResponseFormat = this.#getResponseFormat(finalModel); + const modelWithTools = await this.#bindTools( + finalModel, + preparedOptions, + structuredResponseFormat + ); + let modelInput = this.#getModelInputState(state); - /** - * Use messages from preparedOptions if provided - */ - if (preparedOptions?.messages) { - modelInput = { ...modelInput, messages: preparedOptions.messages }; - } + /** + * Use messages from preparedOptions if provided + */ + if (preparedOptions?.messages) { + modelInput = { ...modelInput, messages: preparedOptions.messages }; + } - const signal = mergeAbortSignals(this.#options.signal, config.signal); - const invokeConfig = { ...config, signal }; - const response = (await modelWithTools.invoke( - modelInput, - invokeConfig - )) as AIMessage; + const signal = mergeAbortSignals(this.#options.signal, config.signal); + const invokeConfig = { ...config, signal }; + const response = (await modelWithTools.invoke( + modelInput, + invokeConfig + )) as AIMessage; - /** - * if the user requests a native schema output, try to parse the response - * and return the structured response if it is valid - */ - if (structuredResponseFormat?.type === "native") { - const structuredResponse = - structuredResponseFormat.strategy.parse(response); - if (structuredResponse) { - return { structuredResponse, messages: [response] }; - } + /** + * if the user requests a native schema output, try to parse the response + * and return the structured response if it is valid + */ + if (structuredResponseFormat?.type === "native") { + const structuredResponse = + structuredResponseFormat.strategy.parse(response); + if (structuredResponse) { + return { structuredResponse, messages: 
[response] }; + } + + return response; + } - return response; - } + if (!structuredResponseFormat || !response.tool_calls) { + return response; + } - if (!structuredResponseFormat || !response.tool_calls) { - return response; - } + const toolCalls = response.tool_calls.filter( + (call) => call.name in structuredResponseFormat.tools + ); - const toolCalls = response.tool_calls.filter( - (call) => call.name in structuredResponseFormat.tools - ); + /** + * if there were not structured tool calls, we can return the response + */ + if (toolCalls.length === 0) { + return response; + } - /** - * if there were not structured tool calls, we can return the response - */ - if (toolCalls.length === 0) { - return response; - } + /** + * if there were multiple structured tool calls, we should throw an error as this + * scenario is not defined/supported. + */ + if (toolCalls.length > 1) { + return this.#handleMultipleStructuredOutputs( + response, + toolCalls, + structuredResponseFormat + ); + } - /** - * if there were multiple structured tool calls, we should throw an error as this - * scenario is not defined/supported. - */ - if (toolCalls.length > 1) { - return this.#handleMultipleStructuredOutputs( - response, - toolCalls, - structuredResponseFormat - ); + const toolStrategy = structuredResponseFormat.tools[toolCalls[0].name]; + const toolMessageContent = toolStrategy?.options?.toolMessageContent; + return this.#handleSingleStructuredOutput( + response, + toolCalls[0], + structuredResponseFormat, + toolMessageContent ?? options.lastMessage + ); + } catch (error) { + // Try retry_model_request on each middleware + const retryMiddleware = + this.#options.retryModelRequestHookMiddleware ?? 
[]; + let shouldRetry = false; + + for (const [middleware, getMiddlewareState] of retryMiddleware) { + if (middleware.retryModelRequest) { + /** + * Cast config to LangGraphRunnableConfig to access LangGraph-specific properties + */ + const lgConfig = config as LangGraphRunnableConfig; + + /** + * Merge context with default context of middleware + */ + const context = middleware.contextSchema + ? interopParse(middleware.contextSchema, lgConfig?.context || {}) + : lgConfig?.context; + + /** + * Create runtime + */ + const privateState = this.getState()._privateState; + const runtime: Runtime = { + ...privateState, + context, + writer: lgConfig.writer, + interrupt: lgConfig.interrupt, + signal: lgConfig.signal, + }; + + const retryRequest = await middleware.retryModelRequest( + error as Error, + { + model: preparedOptions?.model ?? model, + systemPrompt: preparedOptions?.systemPrompt, + messages: preparedOptions?.messages ?? state.messages, + tools: this.#options.toolClasses, + }, + { + ...getMiddlewareState(), + messages: state.messages, + }, + /** + * ensure runtime is frozen to prevent modifications + */ + Object.freeze({ + ...runtime, + context, + }), + attempt + ); + + if (retryRequest) { + // Update preparedOptions with the modified request + preparedOptions = retryRequest; + shouldRetry = true; + // Break on first middleware that wants to retry + break; + } + } + } + + // If no middleware wants to retry, re-raise the error + if (!shouldRetry) { + throw error; + } + } } - const toolStrategy = structuredResponseFormat.tools[toolCalls[0].name]; - const toolMessageContent = toolStrategy?.options?.toolMessageContent; - return this.#handleSingleStructuredOutput( - response, - toolCalls[0], - structuredResponseFormat, - toolMessageContent ?? 
options.lastMessage - ); + // If we exit the loop, max attempts exceeded + throw new Error(`Maximum retry attempts (${maxAttempts}) exceeded`); } /** diff --git a/libs/langchain/src/agents/middlewareAgent/tests/middleware.test.ts b/libs/langchain/src/agents/middlewareAgent/tests/middleware.test.ts index d56066d655fd..9eb37348e725 100644 --- a/libs/langchain/src/agents/middlewareAgent/tests/middleware.test.ts +++ b/libs/langchain/src/agents/middlewareAgent/tests/middleware.test.ts @@ -433,4 +433,82 @@ describe("middleware", () => { ); }); }); + + describe("retryModelRequest", () => { + it("should retry the model request with the new model", async () => { + const model = createMockModel(); + model.invoke = vi.fn().mockRejectedValue(new Error("Model error")); + const retryModel = createMockModel("ChatAnthropic", "anthropic"); + const middleware = createMiddleware({ + name: "middleware", + retryModelRequest: async (_, request) => { + return { + ...request, + model: retryModel, + }; + }, + }); + const agent = createAgent({ + model, + tools: [], + middleware: [middleware] as const, + }); + await agent.invoke({ messages: [new HumanMessage("Hello, world!")] }); + expect(model.invoke).toHaveBeenCalledTimes(1); + expect(retryModel.invoke).toHaveBeenCalledTimes(1); + }); + + it("should not retry the model request if the middleware returns undefined", async () => { + const model = createMockModel(); + model.invoke = vi.fn().mockRejectedValue(new Error("Model error")); + const retryModel = createMockModel("ChatAnthropic", "anthropic"); + const middleware = createMiddleware({ + name: "middleware", + retryModelRequest: async () => { + return; + }, + }); + const agent = createAgent({ + model, + tools: [], + middleware: [middleware] as const, + }); + await expect( + agent.invoke({ messages: [new HumanMessage("Hello, world!")] }) + ).rejects.toThrow("Model error"); + expect(model.invoke).toHaveBeenCalledTimes(1); + expect(retryModel.invoke).toHaveBeenCalledTimes(0); + }); + + 
it("should break after the first middleware that returns a request", async () => { + const model = createMockModel(); + model.invoke = vi + .fn() + .mockRejectedValueOnce(new Error("Model error")) + .mockResolvedValueOnce(new AIMessage("Response from model")); + const retryModel = createMockModel("ChatAnthropic", "anthropic"); + const middleware1 = createMiddleware({ + name: "middleware1", + retryModelRequest: async (_, request) => request, + }); + const middleware2 = createMiddleware({ + name: "middleware2", + retryModelRequest: async (_, request) => { + return { + ...request, + model: retryModel, + }; + }, + }); + const agent = createAgent({ + model, + tools: [], + middleware: [middleware1, middleware2] as const, + }); + + await agent.invoke({ messages: [new HumanMessage("Hello, world!")] }); + expect(model.invoke).toHaveBeenCalledTimes(2); + expect(retryModel.invoke).toHaveBeenCalledTimes(0); + }); + }); }); diff --git a/libs/langchain/src/agents/middlewareAgent/types.ts b/libs/langchain/src/agents/middlewareAgent/types.ts index d13cda48874c..ee1c8e524ffc 100644 --- a/libs/langchain/src/agents/middlewareAgent/types.ts +++ b/libs/langchain/src/agents/middlewareAgent/types.ts @@ -378,6 +378,26 @@ export interface AgentMiddleware< AgentBuiltInState, runtime: Runtime ): Promise | void> | Partial | void; + /** + * Logic to handle model invocation errors and optionally retry. + * + * @param error - The exception that occurred during model invocation. + * @param request - The original model request that failed. + * @param state - The current agent state. + * @param runtime - The runtime context. + * @param attempt - The current attempt number (1-indexed). + * @returns Modified request to retry with, or undefined/null to propagate the error (re-raise). + */ + retryModelRequest?( + error: Error, + request: ModelRequest, + state: (TSchema extends InteropZodObject + ? 
InferInteropZodInput + : {}) & + AgentBuiltInState, + runtime: Runtime, + attempt: number + ): Promise | ModelRequest | void; beforeModel?( state: (TSchema extends InteropZodObject ? InferInteropZodInput