diff --git a/langgraph_supervisor/supervisor.py b/langgraph_supervisor/supervisor.py index 77a4ced..e047a02 100644 --- a/langgraph_supervisor/supervisor.py +++ b/langgraph_supervisor/supervisor.py @@ -9,6 +9,7 @@ from langgraph.prebuilt import ToolNode from langgraph.prebuilt.chat_agent_executor import ( AgentState, + AgentStateWithStructuredResponse, Prompt, StateSchemaType, StructuredResponseSchema, @@ -201,7 +202,7 @@ def create_supervisor( Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] ] = None, parallel_tool_calls: bool = False, - state_schema: StateSchemaType = AgentState, + state_schema: StateSchemaType | None = None, config_schema: Type[Any] | None = None, output_mode: OutputMode = "last_message", add_handoff_messages: bool = True, @@ -323,6 +324,12 @@ def web_search(query: str) -> str: """ if add_handoff_back_messages is None: add_handoff_back_messages = add_handoff_messages + + if state_schema is None: + state_schema = ( + AgentStateWithStructuredResponse if response_format is not None else AgentState + ) + agent_names = set() for agent in agents: if agent.name is None or agent.name == "LangGraph":