diff --git a/examples/slackbot/src/slackbot/api.py b/examples/slackbot/src/slackbot/api.py index 175d0de83..20abcae7c 100644 --- a/examples/slackbot/src/slackbot/api.py +++ b/examples/slackbot/src/slackbot/api.py @@ -32,7 +32,12 @@ post_slack_message, ) from slackbot.strings import count_tokens, slice_tokens -from slackbot.wrap import WatchToolCalls, _progress_message, _tool_usage_counts +from slackbot.wrap import ( + ToolUseLimitExceeded, + WatchToolCalls, + _progress_message, + _tool_usage_counts, +) BOT_MENTION = r"<@(\w+)>" @@ -69,7 +74,10 @@ async def run_agent( counts_token = _tool_usage_counts.set(defaultdict(int)) try: - with WatchToolCalls(settings=decorator_settings): + with WatchToolCalls( + settings=decorator_settings, + max_tool_calls=settings.max_tool_calls_per_turn, + ): result = await create_agent(model=settings.model_name).run( user_prompt=cleaned_message, message_history=conversation, @@ -154,6 +162,19 @@ async def handle_message(payload: SlackPayload, db: Database): channel_id=event.channel, thread_ts=thread_ts, ) + except ToolUseLimitExceeded as e: + logger.warning(f"Tool use limit exceeded: {e}") + assert event.channel is not None, "No channel found" + await task(post_slack_message)( + message=str(e), + channel_id=event.channel, + thread_ts=thread_ts, + ) + return Completed( + message="Tool use limit exceeded", + name="LIMIT_EXCEEDED", + data=dict(user_context=user_context), + ) except Exception as e: logger.error(f"Error running agent: {e}") assert event.channel is not None, "No channel found" diff --git a/examples/slackbot/src/slackbot/settings.py b/examples/slackbot/src/slackbot/settings.py index 7acde640a..250e3ece8 100644 --- a/examples/slackbot/src/slackbot/settings.py +++ b/examples/slackbot/src/slackbot/settings.py @@ -78,20 +78,24 @@ def validate_log_level(cls, v: str) -> str: description="Slack user ID to notify when discussions are created (e.g., U1234567890)", ) + # Tool use limits + max_tool_calls_per_turn: int = Field( + default=50, + description="Maximum number of tool calls allowed per agent turn to prevent runaway tool use", + ) + @model_validator(mode="after") - def validate_temperature(self) -> "SlackbotSettings": + def _apply_post_validation_defaults(self) -> "SlackbotSettings": if "gpt-5" in self.model_name: self.temperature = 1.0 - return self - - @model_validator(mode="after") - def set_turbopuffer_api_key(self) -> "SlackbotSettings": if not os.getenv("TURBOPUFFER_API_KEY"): try: api_key = Secret.load("tpuf-api-key", _sync=True).get() # type: ignore os.environ["TURBOPUFFER_API_KEY"] = api_key except Exception: pass # If secret doesn't exist, turbopuffer will handle the error + if not self.admin_slack_user_id: + self.admin_slack_user_id = Variable.get("admin-slack-id", _sync=True) return self @property diff --git a/examples/slackbot/src/slackbot/wrap.py b/examples/slackbot/src/slackbot/wrap.py index 2dc32ee36..a9b1dacd1 100644 --- a/examples/slackbot/src/slackbot/wrap.py +++ b/examples/slackbot/src/slackbot/wrap.py @@ -10,6 +10,13 @@ T = TypeVar("T") + +class ToolUseLimitExceeded(Exception): + """Raised when tool use limit is exceeded.""" + + pass + + _progress_message: ContextVar[Any] = ContextVar("progress_message", default=None) _tool_usage_counts: ContextVar[dict[str, int] | None] = ContextVar( "tool_usage_counts", default=None @@ -64,6 +71,7 @@ def prefect_wrapped_function( decorator: Callable[..., Callable[..., T]] = task, tags: set[str] | None = None, settings: dict[str, Any] | None = None, + max_tool_calls: int = 10, # Default limit per agent run (matches settings) ) -> Callable[..., Callable[..., T]]: """Decorator for wrapping a function with a prefect decorator.""" tags = tags or set[str]() @@ -102,6 +110,14 @@ async def wrapper(*args, **kwargs) -> T: _tool_usage_counts.set(counts) counts[tool_name] += 1 + # Check if we've exceeded the limit + total_calls = sum(counts.values()) + if total_calls > max_tool_calls: + # Raise an exception to preserve type safety + raise ToolUseLimitExceeded( + "I've reached my tool use limit for this response. Please ask a follow-up question if you need more information." + ) + # Set current tool _current_tool_token = _current_tool.set(tool_name) @@ -161,11 +177,13 @@ def __init__( patch_method_name: str = "call_tool", tags: set[str] | None = None, settings: dict[str, Any] | None = None, + max_tool_calls: int = 10, ): """Initialize the context manager. Args: tags: Prefect tags to apply to the flow. - flow_kwargs: Keyword arguments to pass to the flow. + settings: Settings to pass to the decorator. + max_tool_calls: Maximum number of tool calls allowed per turn. """ # Import here to avoid circular imports from pydantic_ai.toolsets.abstract import AbstractToolset @@ -176,4 +194,5 @@ def __init__( decorator=prefect_wrapped_function, tags=tags, settings=settings, + max_tool_calls=max_tool_calls, )