|
17 | 17 | ) |
18 | 18 |
|
19 | 19 | from marimo import _loggers |
| 20 | +from marimo._ai._tools.types import ToolGuidelines |
20 | 21 | from marimo._ai._tools.utils.exceptions import ToolExecutionError |
21 | 22 | from marimo._config.config import CopilotMode |
22 | 23 | from marimo._server.ai.tools.types import ( |
@@ -95,6 +96,7 @@ class ToolBase(Generic[ArgsT, OutT], ABC): |
95 | 96 | # Override in subclass, or rely on fallbacks below |
96 | 97 | name: str = "" |
97 | 98 | description: str = "" |
| 99 | + guidelines: Optional[ToolGuidelines] = None |
98 | 100 | Args: type[ArgsT] |
99 | 101 | Output: type[OutT] |
100 | 102 | context: ToolContext |
@@ -127,7 +129,15 @@ def __init__(self, context: ToolContext) -> None: |
127 | 129 |
|
128 | 130 | # get description from class docstring |
129 | 131 | if self.description == "": |
130 | | - self.description = (self.__class__.__doc__ or "").strip() |
| 132 | + base_description = (self.__class__.__doc__ or "").strip() |
| 133 | + |
| 134 | + # If guidelines exist, append them |
| 135 | + if self.guidelines is not None: |
| 136 | + self.description = self._format_with_guidelines( |
| 137 | + base_description, self.guidelines |
| 138 | + ) |
| 139 | + else: |
| 140 | + self.description = base_description |
131 | 141 |
|
132 | 142 | async def __call__(self, args: ArgsT) -> OutT: |
133 | 143 | """ |
@@ -241,6 +251,34 @@ def _coerce_args(self, args: Any) -> ArgsT: # type: ignore[override] |
241 | 251 | return args # type: ignore[return-value] |
242 | 252 | return parse_raw(args, self.Args) |
243 | 253 |
|
| 254 | + def _format_with_guidelines( |
| 255 | + self, description: str, guidelines: ToolGuidelines |
| 256 | + ) -> str: |
| 257 | + """Combine description with structured guidelines.""" |
| 258 | + parts = [description] if description else [] |
| 259 | + |
| 260 | + if guidelines.when_to_use: |
| 261 | + parts.append("\n## When to use:") |
| 262 | + parts.extend(f"- {item}" for item in guidelines.when_to_use) |
| 263 | + |
| 264 | + if guidelines.avoid_if: |
| 265 | + parts.append("\n## Avoid if:") |
| 266 | + parts.extend(f"- {item}" for item in guidelines.avoid_if) |
| 267 | + |
| 268 | + if guidelines.prerequisites: |
| 269 | + parts.append("\n## Prerequisites:") |
| 270 | + parts.extend(f"- {item}" for item in guidelines.prerequisites) |
| 271 | + |
| 272 | + if guidelines.side_effects: |
| 273 | + parts.append("\n## Side effects:") |
| 274 | + parts.extend(f"- {item}" for item in guidelines.side_effects) |
| 275 | + |
| 276 | + if guidelines.additional_info: |
| 277 | + parts.append("\n## Additional info:") |
| 278 | + parts.append(guidelines.additional_info) |
| 279 | + |
| 280 | + return "\n".join(parts) |
| 281 | + |
244 | 282 | # error defaults/hooks |
245 | 283 | def _default_error_code(self) -> str: |
246 | 284 | return "UNEXPECTED_ERROR" |
|
0 commit comments