Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions langgraph_supervisor/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from typing import Any, Callable, Literal, Optional, Sequence, Type, Union, cast, get_args
from uuid import UUID, uuid5

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.runnables import RunnableConfig
Expand All @@ -15,6 +16,7 @@
create_react_agent,
)
from langgraph.pregel import Pregel
from langgraph.utils.config import patch_configurable
from langgraph.utils.runnable import RunnableCallable

from langgraph_supervisor.agent_name import AgentNameMode, with_agent_name
Expand Down Expand Up @@ -85,11 +87,25 @@ def _process_output(output: dict) -> dict:
}

def call_agent(state: dict, config: RunnableConfig) -> dict:
output = agent.invoke(state, config)
thread_id = config["configurable"].get("thread_id")
output = agent.invoke(
state,
patch_configurable(
config,
{"thread_id": uuid5(UUID(str(thread_id)), agent.name) if thread_id else None},
),
)
return _process_output(output)

async def acall_agent(state: dict, config: RunnableConfig) -> dict:
output = await agent.ainvoke(state, config)
thread_id = config["configurable"].get("thread_id")
output = await agent.ainvoke(
state,
patch_configurable(
config,
{"thread_id": uuid5(UUID(str(thread_id)), agent.name) if thread_id else None},
),
)
return _process_output(output)

return RunnableCallable(call_agent, acall_agent)
Expand Down