Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions langgraph_supervisor/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from uuid import UUID, uuid5

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt.chat_agent_executor import (
AgentState,
Expand All @@ -20,6 +22,7 @@
from langgraph.pregel.remote import RemoteGraph
from langgraph.utils.config import patch_configurable
from langgraph.utils.runnable import RunnableCallable, RunnableLike
from typing_extensions import Annotated, TypedDict

from langgraph_supervisor.agent_name import AgentNameMode, with_agent_name
from langgraph_supervisor.handoff import (
Expand Down Expand Up @@ -192,6 +195,12 @@ def _prepare_tool_node(
return tool_node


class _OuterState(TypedDict):
"""The state of the supervisor workflow."""

messages: Annotated[Sequence[AnyMessage], add_messages]


def create_supervisor(
agents: list[Pregel],
*,
Expand Down Expand Up @@ -362,10 +371,10 @@ def web_search(query: str) -> str:
if add_handoff_back_messages is None:
add_handoff_back_messages = add_handoff_messages

if state_schema is None:
state_schema = (
AgentStateWithStructuredResponse if response_format is not None else AgentState
)
supervisor_schema = state_schema or (
AgentStateWithStructuredResponse if response_format is not None else AgentState
)
workflow_schema = state_schema or _OuterState

agent_names = set()
for agent in agents:
Expand Down Expand Up @@ -406,13 +415,13 @@ def web_search(query: str) -> str:
model=model,
tools=tool_node,
prompt=prompt,
state_schema=state_schema,
state_schema=supervisor_schema,
response_format=response_format,
pre_model_hook=pre_model_hook,
post_model_hook=post_model_hook,
)

builder = StateGraph(state_schema, config_schema=config_schema)
builder = StateGraph(workflow_schema, config_schema=config_schema)
builder.add_node(supervisor_agent, destinations=tuple(agent_names) + (END,))
builder.add_edge(START, supervisor_agent.name)
for agent in agents:
Expand Down