Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions libs/giskard-agents/src/giskard/agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,16 @@ async def _run_tools(self, chat: Chat[Any]) -> AsyncGenerator[Message, None]:
return

for tool_call in chat.last.tool_calls:
if tool_call.function.name not in self._workflow.tools:
continue # TODO: raise an error?
tool_name = tool_call.function.name or "<missing>"
if tool_name not in self._workflow.tools:
registered_tools = ", ".join(sorted(self._workflow.tools)) or "<none>"
raise ValueError(
f"Unknown tool call '{tool_name}' "
f"(tool_call_id='{tool_call.id}'). "
f"Registered tools: {registered_tools}."
)

tool = self._workflow.tools[tool_call.function.name]
tool = self._workflow.tools[tool_name]
tool_content = await tool.run(
json.loads(tool_call.function.arguments),
ctx=chat.context,
Expand Down
65 changes: 65 additions & 0 deletions libs/giskard-agents/tests/test_workflow_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from giskard import agents
from giskard.agents.errors import WorkflowError
from giskard.agents.generators.base import Response
from giskard.agents.tools import Function, ToolCall
from giskard.agents.workflow import ErrorPolicy
from pydantic import Field, PrivateAttr

Expand Down Expand Up @@ -31,6 +32,41 @@ async def _call_model(
)


class UnknownToolCallGenerator(agents.generators.BaseGenerator):
    """Generator stub that requests the tool ``missing_tool`` on its first call.

    The first ``_call_model`` invocation returns an assistant message holding
    a single tool call (id ``call_unknown_1``); every later invocation returns
    a plain assistant text message. ``_num_calls`` counts the invocations so
    tests can check how many completions were requested.
    """

    _num_calls: int = PrivateAttr(default=0)

    @override
    async def _call_model(
        self,
        messages: list[agents.chat.Message],
        params: agents.generators.GenerationParams,
        metadata: dict[str, Any] | None = None,
    ) -> Response:
        """Return the scripted response for the current call number."""
        self._num_calls += 1

        if self._num_calls != 1:
            # Any call other than the first gets a plain text completion.
            followup = agents.chat.Message(
                role="assistant",
                content="unexpected second completion",
            )
            return Response(message=followup, finish_reason="stop")

        # First call: ask for a tool named "missing_tool" with empty arguments.
        unknown_call = ToolCall(
            id="call_unknown_1",
            function=Function(name="missing_tool", arguments="{}"),
        )
        tool_request = agents.chat.Message(
            role="assistant",
            tool_calls=[unknown_call],
        )
        return Response(message=tool_request, finish_reason="tool_calls")


async def test_run_raises_error():
"""Test that errors are handled correctly."""
workflow = agents.ChatWorkflow(generator=FailingGenerator(fail_after=0))
Expand Down Expand Up @@ -59,6 +95,35 @@ async def test_run_skips_error():
assert chat.failed


async def test_unknown_tool_call_raises_error():
    """Check that an unknown tool call makes ``run()`` raise ``WorkflowError``.

    The wrapped exception must name the unknown tool and report the registered
    tools as ``<none>``, and the generator must have been called exactly once.
    """
    stub = UnknownToolCallGenerator()
    flow = agents.ChatWorkflow(generator=stub)

    with pytest.raises(WorkflowError) as raised:
        _ = await flow.chat("Hello!", role="user").run()

    assert stub._num_calls == 1

    cause = raised.value.exception
    assert cause is not None

    detail = str(cause)
    assert "Unknown tool call 'missing_tool'" in detail
    assert "Registered tools: <none>" in detail


async def test_unknown_tool_call_returns_failed_chat():
    """Check that ``ErrorPolicy.RETURN`` yields a failed chat instead of raising.

    The returned chat must carry the unknown-tool error, contain two messages,
    and end with the assistant message that holds the ``missing_tool`` call.
    The generator must have been called exactly once.
    """
    stub = UnknownToolCallGenerator()
    flow = agents.ChatWorkflow(generator=stub)

    pending = flow.chat("Hello!", role="user").on_error(ErrorPolicy.RETURN)
    chat = await pending.run()

    assert stub._num_calls == 1

    # The failure is recorded on the chat itself.
    assert chat.failed
    assert chat.error is not None
    assert "Unknown tool call 'missing_tool'" in chat.error.message

    # The transcript ends with the assistant's tool-call message.
    assert len(chat.messages) == 2
    final_message = chat.last
    assert final_message.role == "assistant"
    assert final_message.tool_calls is not None
    assert final_message.tool_calls[0].function.name == "missing_tool"


async def test_run_many_raises_error():
"""Test that errors are handled correctly."""

Expand Down
Loading