Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions examples/slackbot/src/slackbot/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
post_slack_message,
)
from slackbot.strings import count_tokens, slice_tokens
from slackbot.wrap import WatchToolCalls, _progress_message, _tool_usage_counts
from slackbot.wrap import (
ToolUseLimitExceeded,
WatchToolCalls,
_progress_message,
_tool_usage_counts,
)

BOT_MENTION = r"<@(\w+)>"

Expand Down Expand Up @@ -69,7 +74,10 @@ async def run_agent(
counts_token = _tool_usage_counts.set(defaultdict(int))

try:
with WatchToolCalls(settings=decorator_settings):
with WatchToolCalls(
settings=decorator_settings,
max_tool_calls=settings.max_tool_calls_per_turn,
):
result = await create_agent(model=settings.model_name).run(
user_prompt=cleaned_message,
message_history=conversation,
Expand Down Expand Up @@ -154,6 +162,19 @@ async def handle_message(payload: SlackPayload, db: Database):
channel_id=event.channel,
thread_ts=thread_ts,
)
except ToolUseLimitExceeded as e:
logger.warning(f"Tool use limit exceeded: {e}")
assert event.channel is not None, "No channel found"
await task(post_slack_message)(
message=str(e),
channel_id=event.channel,
thread_ts=thread_ts,
)
return Completed(
message="Tool use limit exceeded",
name="LIMIT_EXCEEDED",
data=dict(user_context=user_context),
)
except Exception as e:
logger.error(f"Error running agent: {e}")
assert event.channel is not None, "No channel found"
Expand Down
14 changes: 9 additions & 5 deletions examples/slackbot/src/slackbot/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,24 @@ def validate_log_level(cls, v: str) -> str:
description="Slack user ID to notify when discussions are created (e.g., U1234567890)",
)

# Tool use limits
max_tool_calls_per_turn: int = Field(
default=50,
description="Maximum number of tool calls allowed per agent turn to prevent runaway tool use",
)

@model_validator(mode="after")
def validate_temperature(self) -> "SlackbotSettings":
def _apply_post_validation_defaults(self) -> "SlackbotSettings":
if "gpt-5" in self.model_name:
self.temperature = 1.0
return self

@model_validator(mode="after")
def set_turbopuffer_api_key(self) -> "SlackbotSettings":
if not os.getenv("TURBOPUFFER_API_KEY"):
try:
api_key = Secret.load("tpuf-api-key", _sync=True).get() # type: ignore
os.environ["TURBOPUFFER_API_KEY"] = api_key
except Exception:
pass # If secret doesn't exist, turbopuffer will handle the error
if not self.admin_slack_user_id:
Copy link

Copilot AI Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code assumes admin_slack_user_id is optional, but the field definition doesn't show it as Optional. If it's a required field, this check will never be True. If it should be optional, the field type should be Optional[str] or str | None.

Copilot uses AI. Check for mistakes.
self.admin_slack_user_id = Variable.get("admin-slack-id", _sync=True)
return self

@property
Expand Down
21 changes: 20 additions & 1 deletion examples/slackbot/src/slackbot/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

T = TypeVar("T")


class ToolUseLimitExceeded(Exception):
"""Raised when tool use limit is exceeded."""

pass


_progress_message: ContextVar[Any] = ContextVar("progress_message", default=None)
_tool_usage_counts: ContextVar[dict[str, int] | None] = ContextVar(
"tool_usage_counts", default=None
Expand Down Expand Up @@ -64,6 +71,7 @@ def prefect_wrapped_function(
decorator: Callable[..., Callable[..., T]] = task,
tags: set[str] | None = None,
settings: dict[str, Any] | None = None,
max_tool_calls: int = 10, # Default limit per agent run (matches settings)
) -> Callable[..., Callable[..., T]]:
"""Decorator for wrapping a function with a prefect decorator."""
tags = tags or set[str]()
Expand Down Expand Up @@ -102,6 +110,14 @@ async def wrapper(*args, **kwargs) -> T:
_tool_usage_counts.set(counts)
counts[tool_name] += 1

# Check if we've exceeded the limit
total_calls = sum(counts.values())
if total_calls > max_tool_calls:
# Raise an exception to preserve type safety
raise ToolUseLimitExceeded(
"I've reached my tool use limit for this response. Please ask a follow-up question if you need more information."
)

# Set current tool
_current_tool_token = _current_tool.set(tool_name)

Expand Down Expand Up @@ -161,11 +177,13 @@ def __init__(
patch_method_name: str = "call_tool",
tags: set[str] | None = None,
settings: dict[str, Any] | None = None,
max_tool_calls: int = 10,
):
"""Initialize the context manager.
Args:
tags: Prefect tags to apply to the flow.
flow_kwargs: Keyword arguments to pass to the flow.
settings: Settings to pass to the decorator.
max_tool_calls: Maximum number of tool calls allowed per turn.
"""
# Import here to avoid circular imports
from pydantic_ai.toolsets.abstract import AbstractToolset
Expand All @@ -176,4 +194,5 @@ def __init__(
decorator=prefect_wrapped_function,
tags=tags,
settings=settings,
max_tool_calls=max_tool_calls,
)