|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import os |
5 | | -from typing import Callable, Optional, cast |
| 5 | +from typing import Any, Callable, Optional, cast |
6 | 6 |
|
7 | 7 | from marimo._ai._convert import ( |
8 | 8 | convert_to_anthropic_messages, |
@@ -211,29 +211,29 @@ def __call__( |
211 | 211 | DependencyManager.anthropic.require( |
212 | 212 | "chat model requires anthropic. `pip install anthropic`" |
213 | 213 | ) |
214 | | - from anthropic import ( # type: ignore[import-not-found] |
215 | | - NOT_GIVEN, |
216 | | - Anthropic, |
217 | | - ) |
| 214 | + from anthropic import Anthropic |
218 | 215 |
|
219 | 216 | client = Anthropic( |
220 | 217 | api_key=self._require_api_key, |
221 | 218 | base_url=self.base_url, |
222 | 219 | ) |
223 | 220 |
|
224 | 221 | anthropic_messages = convert_to_anthropic_messages(messages) |
225 | | - response = client.messages.create( |
226 | | - model=self.model, |
227 | | - system=self.system_message, |
228 | | - max_tokens=config.max_tokens or 4096, |
229 | | - messages=anthropic_messages, |
230 | | - top_p=config.top_p if config.top_p is not None else NOT_GIVEN, |
231 | | - top_k=config.top_k if config.top_k is not None else NOT_GIVEN, |
232 | | - stream=False, |
233 | | - temperature=config.temperature |
234 | | - if config.temperature is not None |
235 | | - else NOT_GIVEN, |
236 | | - ) |
| 222 | + params: dict[str, Any] = { |
| 223 | + "model": self.model, |
| 224 | + "system": self.system_message, |
| 225 | + "max_tokens": config.max_tokens or 4096, |
| 226 | + "messages": anthropic_messages, |
| 227 | + "stream": False, |
| 228 | + } |
| 229 | + if config.top_p is not None: |
| 230 | + params["top_p"] = config.top_p |
| 231 | + if config.top_k is not None: |
| 232 | + params["top_k"] = config.top_k |
| 233 | + if config.temperature is not None: |
| 234 | + params["temperature"] = config.temperature |
| 235 | + |
| 236 | + response = client.messages.create(**params) |
237 | 237 |
|
238 | 238 | content = response.content |
239 | 239 | if len(content) > 0: |
|
0 commit comments