Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 87 additions & 20 deletions libs/langgraph/src/pregel/debug.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { wrap, tasksWithWrites, _readChannels } from "./debug.js";
import { BaseChannel } from "../channels/base.js";
import { LastValue } from "../channels/last_value.js";
import { EmptyChannelError } from "../errors.js";
import { ERROR, INTERRUPT, PULL } from "../constants.js";

describe("wrap", () => {
it("should wrap text with color codes", () => {
Expand Down Expand Up @@ -107,24 +108,27 @@ describe("tasksWithWrites", () => {
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"] as ["PULL", string],
path: [PULL, "Task 1"] as [typeof PULL, string],
interrupts: [],
},
{
id: "task2",
name: "Task 2",
path: ["PULL", "Task 2"] as ["PULL", string],
path: [PULL, "Task 2"] as [typeof PULL, string],
interrupts: [],
},
];

const pendingWrites: Array<[string, string, unknown]> = [];

const result = tasksWithWrites(tasks, pendingWrites);
const result = tasksWithWrites(tasks, pendingWrites, undefined, [
"Task 1",
"Task 2",
]);

expect(result).toEqual([
{ id: "task1", name: "Task 1", path: ["PULL", "Task 1"], interrupts: [] },
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
{ id: "task1", name: "Task 1", path: [PULL, "Task 1"], interrupts: [] },
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
]);
});

Expand All @@ -133,32 +137,35 @@ describe("tasksWithWrites", () => {
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"] as ["PULL", string],
path: [PULL, "Task 1"] as [typeof PULL, string],
interrupts: [],
},
{
id: "task2",
name: "Task 2",
path: ["PULL", "Task 2"] as ["PULL", string],
path: [PULL, "Task 2"] as [typeof PULL, string],
interrupts: [],
},
];

const pendingWrites: Array<[string, string, unknown]> = [
["task1", "__error__", { message: "Test error" }],
["task1", ERROR, { message: "Test error" }],
];

const result = tasksWithWrites(tasks, pendingWrites);
const result = tasksWithWrites(tasks, pendingWrites, undefined, [
"Task 1",
"Task 2",
]);

expect(result).toEqual([
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"],
path: [PULL, "Task 1"],
error: { message: "Test error" },
interrupts: [],
},
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
]);
});

Expand All @@ -167,13 +174,13 @@ describe("tasksWithWrites", () => {
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"] as ["PULL", string],
path: [PULL, "Task 1"] as [typeof PULL, string],
interrupts: [],
},
{
id: "task2",
name: "Task 2",
path: ["PULL", "Task 2"] as ["PULL", string],
path: [PULL, "Task 2"] as [typeof PULL, string],
interrupts: [],
},
];
Expand All @@ -184,17 +191,20 @@ describe("tasksWithWrites", () => {
task1: { configurable: { key: "value" } },
};

const result = tasksWithWrites(tasks, pendingWrites, states);
const result = tasksWithWrites(tasks, pendingWrites, states, [
"Task 1",
"Task 2",
]);

expect(result).toEqual([
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"],
path: [PULL, "Task 1"],
interrupts: [],
state: { configurable: { key: "value" } },
},
{ id: "task2", name: "Task 2", path: ["PULL", "Task 2"], interrupts: [] },
{ id: "task2", name: "Task 2", path: [PULL, "Task 2"], interrupts: [] },
]);
});

Expand All @@ -203,24 +213,81 @@ describe("tasksWithWrites", () => {
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"] as ["PULL", string],
path: [PULL, "Task 1"] as [typeof PULL, string],
interrupts: [],
},
];

const pendingWrites: Array<[string, string, unknown]> = [
["task1", "__interrupt__", { value: "Interrupted", when: "during" }],
["task1", INTERRUPT, { value: "Interrupted", when: "during" }],
];

const result = tasksWithWrites(tasks, pendingWrites);
const result = tasksWithWrites(tasks, pendingWrites, undefined, ["task1"]);

expect(result).toEqual([
{
id: "task1",
name: "Task 1",
path: ["PULL", "Task 1"],
path: [PULL, "Task 1"],
interrupts: [{ value: "Interrupted", when: "during" }],
},
]);
});

it("should include results", () => {
const tasks = [
{
id: "task1",
name: "Task 1",
path: [PULL, "Task 1"] as [typeof PULL, string],
interrupts: [],
},
{
id: "task2",
name: "Task 2",
path: [PULL, "Task 2"] as [typeof PULL, string],
interrupts: [],
},
{
id: "task3",
name: "Task 3",
path: [PULL, "Task 3"] as [typeof PULL, string],
interrupts: [],
},
];

const pendingWrites: Array<[string, string, unknown]> = [
["task1", "Task 1", "Result"],
["task2", "Task 2", "Result 2"],
];

const result = tasksWithWrites(tasks, pendingWrites, undefined, [
"Task 1",
"Task 2",
]);

expect(result).toEqual([
{
id: "task1",
name: "Task 1",
path: [PULL, "Task 1"],
interrupts: [],
result: { "Task 1": "Result" },
},
{
id: "task2",
name: "Task 2",
path: [PULL, "Task 2"],
interrupts: [],
result: { "Task 2": "Result 2" },
},
{
id: "task3",
name: "Task 3",
path: [PULL, "Task 3"],
interrupts: [],
result: undefined,
},
]);
});
});
57 changes: 47 additions & 10 deletions libs/langgraph/src/pregel/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ import {
PendingWrite,
} from "@langchain/langgraph-checkpoint";
import { BaseChannel } from "../channels/base.js";
import { ERROR, Interrupt, INTERRUPT, TAG_HIDDEN } from "../constants.js";
import {
ERROR,
Interrupt,
INTERRUPT,
RETURN,
TAG_HIDDEN,
} from "../constants.js";
import { EmptyChannelError } from "../errors.js";
import {
PregelExecutableTask,
Expand Down Expand Up @@ -140,6 +146,8 @@ export function* mapDebugTaskResults<
}
}

type ChannelKey = string | number | symbol;

export function* mapDebugCheckpoint<
N extends PropertyKey,
C extends PropertyKey
Expand All @@ -151,7 +159,8 @@ export function* mapDebugCheckpoint<
metadata: CheckpointMetadata,
tasks: readonly PregelExecutableTask<N, C>[],
pendingWrites: CheckpointPendingWrite[],
parentConfig: RunnableConfig | undefined
parentConfig: RunnableConfig | undefined,
outputKeys: ChannelKey | ChannelKey[]
) {
function formatConfig(config: RunnableConfig) {
// https://stackoverflow.com/a/78298178
Expand Down Expand Up @@ -214,7 +223,7 @@ export function* mapDebugCheckpoint<
values: readChannels(channels, streamChannels),
metadata,
next: tasks.map((task) => task.name),
tasks: tasksWithWrites(tasks, pendingWrites, taskStates),
tasks: tasksWithWrites(tasks, pendingWrites, taskStates, outputKeys),
parentConfig: parentConfig ? formatConfig(parentConfig) : undefined,
},
};
Expand All @@ -223,36 +232,64 @@ export function* mapDebugCheckpoint<
export function tasksWithWrites<N extends PropertyKey, C extends PropertyKey>(
tasks: PregelTaskDescription[] | readonly PregelExecutableTask<N, C>[],
pendingWrites: CheckpointPendingWrite[],
states?: Record<string, RunnableConfig | StateSnapshot>
states: Record<string, RunnableConfig | StateSnapshot> | undefined,
outputKeys: ChannelKey[] | ChannelKey
): PregelTaskDescription[] {
return tasks.map((task): PregelTaskDescription => {
const error = pendingWrites.find(
([id, n]) => id === task.id && n === ERROR
)?.[2];

const interrupts = pendingWrites
.filter(([id, n]) => {
return id === task.id && n === INTERRUPT;
})
.map(([, , v]) => {
return v;
}) as Interrupt[];
.filter(([id, n]) => id === task.id && n === INTERRUPT)
.map(([, , v]) => v) as Interrupt[];

const result = (() => {
if (error || interrupts.length || !pendingWrites.length) return undefined;

const idx = pendingWrites.findIndex(
([tid, n]) => tid === task.id && n === RETURN
);

if (idx >= 0) return pendingWrites[idx][2];

if (typeof outputKeys === "string") {
return pendingWrites.find(
([tid, n]) => tid === task.id && n === outputKeys
)?.[2];
}

if (Array.isArray(outputKeys)) {
const results = pendingWrites
.filter(([tid, n]) => tid === task.id && outputKeys.includes(n))
.map(([, n, v]) => [n, v]);

if (!results.length) return undefined;
return Object.fromEntries(results);
}

return undefined;
})();

if (error) {
return {
id: task.id,
name: task.name as string,
path: task.path,
error,
interrupts,
result,
};
}

const taskState = states?.[task.id];
return {
id: task.id,
name: task.name as string,
path: task.path,
interrupts,
...(taskState !== undefined ? { state: taskState } : {}),
result,
};
});
}
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,12 @@ export class Pregel<
this.streamChannelsAsIs as string | string[]
),
next: nextList,
tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? [], taskStates),
tasks: tasksWithWrites(
nextTasks,
saved?.pendingWrites ?? [],
taskStates,
this.streamChannelsAsIs
),
metadata,
config: patchCheckpointMap(saved.config, saved.metadata),
createdAt: saved.checkpoint.ts,
Expand Down
3 changes: 2 additions & 1 deletion libs/langgraph/src/pregel/loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,8 @@ export class PregelLoop {
this.checkpointMetadata,
Object.values(this.tasks),
this.checkpointPendingWrites,
this.prevCheckpointConfig
this.prevCheckpointConfig,
this.outputKeys
),
"debug"
)
Expand Down
1 change: 1 addition & 0 deletions libs/langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ export interface PregelTaskDescription {
readonly interrupts: Interrupt[];
readonly state?: LangGraphRunnableConfig | StateSnapshot;
readonly path?: TaskPath;
readonly result?: unknown;
}

interface CacheKey {
Expand Down
Loading