Skip to content

Commit 1eada02

Browse files
authored
refactor: Replace list_models_async with alist_models. (#442)
For consistency with the rest of the API.
1 parent 2cb4bc2 commit 1eada02

File tree

18 files changed

+58
-64
lines changed

18 files changed

+58
-64
lines changed

docs/api/list_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
## Models
22

33
::: any_llm.api.list_models
4-
::: any_llm.api.list_models_async
4+
::: any_llm.api.alist_models

src/any_llm/any_llm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Inspired by https://github.com/andrewyng/aisuite/tree/main/aisuite
22
from __future__ import annotations
33

4-
import asyncio
54
import importlib
65
import os
76
import warnings
@@ -473,14 +472,14 @@ async def _aembedding(self, model: str, inputs: str | list[str], **kwargs: Any)
473472
raise NotImplementedError(msg)
474473

475474
def list_models(self, **kwargs: Any) -> Sequence[Model]:
476-
"""Return a list of Model if the provider supports listing models.
475+
return run_async_in_sync(self.alist_models(**kwargs), allow_running_loop=INSIDE_NOTEBOOK)
477476

478-
Should be overridden by subclasses.
479-
"""
480-
msg = "Subclasses must implement list_models method"
477+
async def alist_models(self, **kwargs: Any) -> Sequence[Model]:
478+
return await self._alist_models(**kwargs)
479+
480+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
481481
if not self.SUPPORTS_LIST_MODELS:
482+
msg = "Provider doesn't support listing models."
482483
raise NotImplementedError(msg)
484+
msg = "Subclasses must implement _alist_models method"
483485
raise NotImplementedError(msg)
484-
485-
async def list_models_async(self, **kwargs: Any) -> Sequence[Model]:
486-
return await asyncio.to_thread(self.list_models, **kwargs)

src/any_llm/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def list_models(
432432
return llm.list_models(**kwargs)
433433

434434

435-
async def list_models_async(
435+
async def alist_models(
436436
provider: str | LLMProvider,
437437
api_key: str | None = None,
438438
api_base: str | None = None,
@@ -443,4 +443,4 @@ async def list_models_async(
443443
llm = AnyLLM.create(
444444
LLMProvider.from_string(provider), ClientConfig(api_key=api_key, api_base=api_base, client_args=client_args)
445445
)
446-
return await llm.list_models_async(**kwargs)
446+
return await llm.alist_models(**kwargs)

src/any_llm/providers/anthropic/anthropic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
MISSING_PACKAGES_ERROR = None
99
try:
10-
from anthropic import Anthropic, AsyncAnthropic
10+
from anthropic import AsyncAnthropic
1111

1212
from .utils import (
1313
_convert_models_list,
@@ -19,7 +19,6 @@
1919
MISSING_PACKAGES_ERROR = e
2020

2121
if TYPE_CHECKING:
22-
from anthropic.pagination import SyncPage
2322
from anthropic.types import Message
2423
from anthropic.types.model_info import ModelInfo as AnthropicModelInfo
2524

@@ -75,7 +74,7 @@ def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse:
7574
raise NotImplementedError(msg)
7675

7776
@staticmethod
78-
def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]:
77+
def _convert_list_models_response(response: "list[AnthropicModelInfo]") -> Sequence[Model]:
7978
"""Convert Anthropic models list to OpenAI format."""
8079
return _convert_models_list(response)
8180

@@ -111,12 +110,11 @@ async def _acompletion(
111110

112111
return self._convert_completion_response(message)
113112

114-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
115-
"""List available models from Anthropic."""
116-
client = Anthropic(
113+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
114+
client = AsyncAnthropic(
117115
api_key=self.config.api_key,
118116
base_url=self.config.api_base,
119117
**(self.config.client_args if self.config.client_args else {}),
120118
)
121-
models_list = client.models.list(**kwargs)
122-
return self._convert_list_models_response(models_list)
119+
models_list = await client.models.list(**kwargs)
120+
return self._convert_list_models_response(models_list.data)

src/any_llm/providers/anthropic/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
from typing import Any
33

4-
from anthropic.pagination import SyncPage
54
from anthropic.types import (
65
ContentBlockDeltaEvent,
76
ContentBlockStartEvent,
@@ -316,7 +315,7 @@ def _convert_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
316315
return result_kwargs
317316

318317

319-
def _convert_models_list(models_list: SyncPage[AnthropicModelInfo]) -> list[Model]:
318+
def _convert_models_list(models_list: list[AnthropicModelInfo]) -> list[Model]:
320319
"""Convert Anthropic models list to OpenAI format."""
321320
return [
322321
Model(id=model.id, object="model", created=int(model.created_at.timestamp()), owned_by="anthropic")

src/any_llm/providers/bedrock/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def embedding(
239239
response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens}
240240
return self._convert_embedding_response(response_data)
241241

242-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
242+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
243243
"""
244244
Fetch available models from the /v1/models endpoint.
245245
"""

src/any_llm/providers/cerebras/cerebras.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ async def _acompletion(
151151

152152
return self._convert_completion_response(response_data)
153153

154-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
154+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
155155
"""
156156
Fetch available models from the /v1/models endpoint.
157157
"""
158-
client = cerebras.Cerebras(
158+
client = cerebras.AsyncCerebras(
159159
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
160160
)
161-
models_list = client.models.list(**kwargs)
161+
models_list = await client.models.list(**kwargs)
162162
return self._convert_list_models_response(models_list)

src/any_llm/providers/cohere/cohere.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ async def _acompletion(
148148

149149
return self._convert_completion_response(response, model=params.model_id)
150150

151-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
151+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
152152
"""
153153
Fetch available models from the /v1/models endpoint.
154154
"""
155-
client = cohere.ClientV2(
155+
client = cohere.AsyncClientV2(
156156
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
157157
)
158-
model_list = client.models.list(**kwargs)
158+
model_list = await client.models.list(**kwargs)
159159
return self._convert_list_models_response(model_list)

src/any_llm/providers/gemini/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,7 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:
246246
response_dict = _convert_response_to_response_dict(response)
247247
return self._convert_completion_response((response_dict, params.model_id))
248248

249-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
250-
"""Fetch available models from the /v1/models endpoint."""
249+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
251250
client = self._get_client(self.config)
252-
models_list = client.models.list(**kwargs)
251+
models_list = await client.aio.models.list(**kwargs)
253252
return self._convert_list_models_response(models_list)

src/any_llm/providers/groq/groq.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,12 @@ async def _aresponses(
171171

172172
return response
173173

174-
def list_models(self, **kwargs: Any) -> Sequence[Model]:
174+
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
175175
"""
176176
Fetch available models from the /v1/models endpoint.
177177
"""
178-
client = groq.Groq(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}))
179-
models_list = client.models.list(**kwargs)
178+
client = groq.AsyncGroq(
179+
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
180+
)
181+
models_list = await client.models.list(**kwargs)
180182
return self._convert_list_models_response(models_list)

0 commit comments

Comments (0)