Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions langgraph_supervisor/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Callable, Literal, Optional, Type, Union

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt.chat_agent_executor import (
Expand Down Expand Up @@ -81,12 +82,12 @@ def _process_output(output: dict) -> dict:
"messages": messages,
}

def call_agent(state: dict) -> dict:
output = agent.invoke(state)
def call_agent(state: dict, config: RunnableConfig) -> dict:
output = agent.invoke(state, config)
return _process_output(output)

async def acall_agent(state: dict) -> dict:
output = await agent.ainvoke(state)
async def acall_agent(state: dict, config: RunnableConfig) -> dict:
output = await agent.ainvoke(state, config)
return _process_output(output)

return RunnableCallable(call_agent, acall_agent)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool, tool
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import create_react_agent

from langgraph_supervisor import create_supervisor
Expand Down Expand Up @@ -545,3 +547,59 @@ def get_tool_calls(msg):
},
]
assert received == expected


def test_metadata_passed_to_subagent() -> None:
"""Test that metadata from config is passed to sub-agents.

This test verifies that when a config object with metadata is passed to the supervisor,
the metadata is correctly passed to the sub-agent when it is invoked.
"""

# Create a tracking agent to verify metadata is passed
def test_node(_state: MessagesState, config: RunnableConfig):
# Assert that the metadata is passed to the sub-agent
assert config["metadata"]["test_key"] == "test_value"
assert config["metadata"]["another_key"] == 123
# Return a new message if the assertion passes.
return {"messages": [AIMessage(content="Test response")]}

tracking_agent_workflow = StateGraph(MessagesState)
tracking_agent_workflow.add_node("test_node", test_node)
tracking_agent_workflow.set_entry_point("test_node")
tracking_agent_workflow.set_finish_point("test_node")
tracking_agent = tracking_agent_workflow.compile()
tracking_agent.name = "test_agent"

# Create a supervisor with the tracking agent
supervisor_model = FakeChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "transfer_to_test_agent",
"args": {},
"id": "call_123",
"type": "tool_call",
}
],
),
AIMessage(content="Final response"),
]
)

supervisor = create_supervisor(
agents=[tracking_agent],
model=supervisor_model,
).compile()

# Create config with metadata
test_metadata = {"test_key": "test_value", "another_key": 123}
config = {"metadata": test_metadata}

# Invoke the supervisor with the config
result = supervisor.invoke({"messages": [HumanMessage(content="Test message")]}, config=config)
# Get the last message in the messages list & verify it matches the value
# returned from the node.
assert result["messages"][-1].content == "Final response"