diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 51c4659a6..cb5c377b7 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -1,21 +1,24 @@ /* eslint-disable no-param-reassign */ import { - Runnable, - RunnableConfig, - RunnableFunc, - RunnableSequence, + _coerceToRunnable, getCallbackManagerForConfig, mergeConfigs, patchConfig, - _coerceToRunnable, + Runnable, + RunnableConfig, + RunnableFunc, RunnableLike, + RunnableSequence, } from "@langchain/core/runnables"; +import type { StreamEvent } from "@langchain/core/tracers/log_stream"; import { IterableReadableStream } from "@langchain/core/utils/stream"; import { All, + BaseCache, BaseCheckpointSaver, BaseStore, CheckpointListOptions, + CheckpointMetadata, CheckpointTuple, compareChannelVersions, copyCheckpoint, @@ -23,81 +26,38 @@ import { PendingWrite, SCHEDULED, uuid5, - CheckpointMetadata, - BaseCache, } from "@langchain/langgraph-checkpoint"; -import type { StreamEvent } from "@langchain/core/tracers/log_stream"; import { BaseChannel, createCheckpoint, emptyChannels, isBaseChannel, } from "../channels/base.js"; -import { PregelNode } from "./read.js"; -import { validateGraph, validateKeys } from "./validate.js"; -import { mapInput, readChannels } from "./io.js"; -import { - printStepCheckpoint, - printStepTasks, - printStepWrites, - tasksWithWrites, -} from "./debug.js"; -import { ChannelWrite, ChannelWriteEntry, PASSTHROUGH } from "./write.js"; import { + CHECKPOINT_NAMESPACE_END, + CHECKPOINT_NAMESPACE_SEPARATOR, + Command, CONFIG_KEY_CHECKPOINTER, + CONFIG_KEY_NODE_FINISHED, CONFIG_KEY_READ, CONFIG_KEY_SEND, + CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, + COPY, + END, ERROR, INPUT, INTERRUPT, - PUSH, - CHECKPOINT_NAMESPACE_SEPARATOR, - CHECKPOINT_NAMESPACE_END, - CONFIG_KEY_STREAM, - Command, - NULL_TASK_ID, - COPY, - END, - CONFIG_KEY_NODE_FINISHED, Interrupt, isInterrupted, + NULL_TASK_ID, + PUSH, } from "../constants.js"; -import { - PregelExecutableTask, - PregelInterface, - PregelParams, - StateSnapshot, - StreamMode, - PregelInputType, - PregelOutputType, - PregelOptions, - SingleChannelSubscriptionOptions, - MultipleChannelSubscriptionOptions, - GetStateOptions, - type StreamOutputMap, -} from "./types.js"; import { GraphRecursionError, GraphValueError, InvalidUpdateError, } from "../errors.js"; -import { - _prepareNextTasks, - _localRead, - _applyWrites, - StrRecord, - WritesProtocol, -} from "./algo.js"; -import { - _coerceToDict, - combineAbortSignals, - getNewChannelVersions, - patchCheckpointMap, - RetryPolicy, -} from "./utils/index.js"; -import { findSubgraphPregel } from "./utils/subgraph.js"; -import { PregelLoop } from "./loop.js"; import { ChannelKeyPlaceholder, isConfiguredManagedValue, @@ -108,16 +68,57 @@ import { } from "../managed/base.js"; import { gatherIterator, patchConfigurable } from "../utils.js"; import { - ensureLangGraphConfig, - recastCheckpointNamespace, -} from "./utils/config.js"; -import { LangGraphRunnableConfig } from "./runnable_types.js"; + _applyWrites, + _localRead, + _prepareNextTasks, + StrRecord, + WritesProtocol, +} from "./algo.js"; +import { + printStepCheckpoint, + printStepTasks, + printStepWrites, + tasksWithWrites, +} from "./debug.js"; +import { mapInput, readChannels } from "./io.js"; +import { PregelLoop } from "./loop.js"; import { StreamMessagesHandler } from "./messages.js"; +import { PregelNode } from "./read.js"; +import { LangGraphRunnableConfig } from "./runnable_types.js"; import { PregelRunner } from "./runner.js"; import { IterableReadableStreamWithAbortSignal, IterableReadableWritableStream, } from "./stream.js"; +import { + GetStateOptions, + MultipleChannelSubscriptionOptions, + PregelExecutableTask, + PregelInputType, + PregelInterface, + PregelOptions, + PregelOutputType, + PregelParams, + SingleChannelSubscriptionOptions, + StateSnapshot, + StreamMode, + type StreamOutputMap, +} from "./types.js"; +import { + ensureLangGraphConfig, + recastCheckpointNamespace, +} from "./utils/config.js"; +import { + _coerceToDict, + combineAbortSignals, + combineCallbacks, + getNewChannelVersions, + patchCheckpointMap, + RetryPolicy, +} from "./utils/index.js"; +import { findSubgraphPregel } from "./utils/subgraph.js"; +import { validateGraph, validateKeys } from "./validate.js"; +import { ChannelWrite, ChannelWriteEntry, PASSTHROUGH } from "./write.js"; type WriteValue = Runnable | RunnableFunc | unknown; type StreamEventsOptions = Parameters[2]; @@ -282,7 +283,7 @@ export class Channel { } } -export type { PregelInputType, PregelOutputType, PregelOptions }; +export type { PregelInputType, PregelOptions, PregelOutputType }; // This is a workaround to allow Pregel to override `invoke` / `stream` and `withConfig` // without having to adhere to the types in the `Runnable` class (thanks to `any`). @@ -1806,10 +1807,12 @@ export class Pregel< const config = { recursionLimit: this.config?.recursionLimit, + ...options, // Similar to `stream`, we need to pass the `config.callbacks` here, // otherwise the user-provided callback will get lost in `ensureLangGraphConfig`. - callbacks: this.config?.callbacks, - ...options, + + // extend the callbacks with the ones from the config + callbacks: combineCallbacks(this.config?.callbacks, options?.callbacks), signal: options?.signal ? combineAbortSignals(options.signal, abortController.signal) : abortController.signal, diff --git a/libs/langgraph/src/pregel/utils/index.ts b/libs/langgraph/src/pregel/utils/index.ts index de151b500..3685fe753 100644 --- a/libs/langgraph/src/pregel/utils/index.ts +++ b/libs/langgraph/src/pregel/utils/index.ts @@ -1,3 +1,4 @@ +import { Callbacks } from "@langchain/core/callbacks/manager"; import { RunnableConfig } from "@langchain/core/runnables"; import type { ChannelVersions, @@ -164,3 +165,36 @@ export function combineAbortSignals(...signals: AbortSignal[]): AbortSignal { return combinedController.signal; } + +/** + * Combine multiple callbacks into a single callback. + * @param callback1 - The first callback to combine. + * @param callback2 - The second callback to combine. + * @returns A single callback that is a combination of the input callbacks. + */ +export const combineCallbacks = ( + callback1?: Callbacks, + callback2?: Callbacks +): Callbacks | undefined => { + if (!callback1 && !callback2) { + return undefined; + } + + if (!callback1) { + return callback2; + } + + if (!callback2) { + return callback1; + } + if (Array.isArray(callback1) && Array.isArray(callback2)) { + return [...callback1, ...callback2]; + } + if (Array.isArray(callback1)) { + return [...callback1, callback2] as Callbacks; + } + if (Array.isArray(callback2)) { + return [callback1, ...callback2]; + } + return [callback1, callback2] as Callbacks; +};