Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/list_models.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Models

::: any_llm.api.list_models
::: any_llm.api.list_models_async
::: any_llm.api.alist_models
15 changes: 7 additions & 8 deletions src/any_llm/any_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Inspired by https://github.com/andrewyng/aisuite/tree/main/aisuite
from __future__ import annotations

import asyncio
import importlib
import os
import warnings
Expand Down Expand Up @@ -473,14 +472,14 @@ async def _aembedding(self, model: str, inputs: str | list[str], **kwargs: Any)
raise NotImplementedError(msg)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""Return a list of Model if the provider supports listing models.
return run_async_in_sync(self.alist_models(**kwargs), allow_running_loop=INSIDE_NOTEBOOK)

Should be overridden by subclasses.
"""
msg = "Subclasses must implement list_models method"
async def alist_models(self, **kwargs: Any) -> Sequence[Model]:
return await self._alist_models(**kwargs)

async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
if not self.SUPPORTS_LIST_MODELS:
msg = "Provider doesn't support listing models."
raise NotImplementedError(msg)
msg = "Subclasses must implement _alist_models method"
raise NotImplementedError(msg)

async def list_models_async(self, **kwargs: Any) -> Sequence[Model]:
return await asyncio.to_thread(self.list_models, **kwargs)
4 changes: 2 additions & 2 deletions src/any_llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def list_models(
return llm.list_models(**kwargs)


async def list_models_async(
async def alist_models(
provider: str | LLMProvider,
api_key: str | None = None,
api_base: str | None = None,
Expand All @@ -443,4 +443,4 @@ async def list_models_async(
llm = AnyLLM.create(
LLMProvider.from_string(provider), ClientConfig(api_key=api_key, api_base=api_base, client_args=client_args)
)
return await llm.list_models_async(**kwargs)
return await llm.alist_models(**kwargs)
14 changes: 6 additions & 8 deletions src/any_llm/providers/anthropic/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

MISSING_PACKAGES_ERROR = None
try:
from anthropic import Anthropic, AsyncAnthropic
from anthropic import AsyncAnthropic

from .utils import (
_convert_models_list,
Expand All @@ -19,7 +19,6 @@
MISSING_PACKAGES_ERROR = e

if TYPE_CHECKING:
from anthropic.pagination import SyncPage
from anthropic.types import Message
from anthropic.types.model_info import ModelInfo as AnthropicModelInfo

Expand Down Expand Up @@ -75,7 +74,7 @@ def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse:
raise NotImplementedError(msg)

@staticmethod
def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]:
def _convert_list_models_response(response: "list[AnthropicModelInfo]") -> Sequence[Model]:
"""Convert Anthropic models list to OpenAI format."""
return _convert_models_list(response)

Expand Down Expand Up @@ -111,12 +110,11 @@ async def _acompletion(

return self._convert_completion_response(message)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""List available models from Anthropic."""
client = Anthropic(
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
client = AsyncAnthropic(
api_key=self.config.api_key,
base_url=self.config.api_base,
**(self.config.client_args if self.config.client_args else {}),
)
models_list = client.models.list(**kwargs)
return self._convert_list_models_response(models_list)
models_list = await client.models.list(**kwargs)
return self._convert_list_models_response(models_list.data)
3 changes: 1 addition & 2 deletions src/any_llm/providers/anthropic/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
from typing import Any

from anthropic.pagination import SyncPage
from anthropic.types import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
Expand Down Expand Up @@ -316,7 +315,7 @@ def _convert_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
return result_kwargs


def _convert_models_list(models_list: SyncPage[AnthropicModelInfo]) -> list[Model]:
def _convert_models_list(models_list: list[AnthropicModelInfo]) -> list[Model]:
"""Convert Anthropic models list to OpenAI format."""
return [
Model(id=model.id, object="model", created=int(model.created_at.timestamp()), owned_by="anthropic")
Expand Down
2 changes: 1 addition & 1 deletion src/any_llm/providers/bedrock/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def embedding(
response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens}
return self._convert_embedding_response(response_data)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/any_llm/providers/cerebras/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ async def _acompletion(

return self._convert_completion_response(response_data)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = cerebras.Cerebras(
client = cerebras.AsyncCerebras(
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
)
models_list = client.models.list(**kwargs)
models_list = await client.models.list(**kwargs)
return self._convert_list_models_response(models_list)
6 changes: 3 additions & 3 deletions src/any_llm/providers/cohere/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ async def _acompletion(

return self._convert_completion_response(response, model=params.model_id)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = cohere.ClientV2(
client = cohere.AsyncClientV2(
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
)
model_list = client.models.list(**kwargs)
model_list = await client.models.list(**kwargs)
return self._convert_list_models_response(model_list)
5 changes: 2 additions & 3 deletions src/any_llm/providers/gemini/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:
response_dict = _convert_response_to_response_dict(response)
return self._convert_completion_response((response_dict, params.model_id))

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""Fetch available models from the /v1/models endpoint."""
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
client = self._get_client(self.config)
models_list = client.models.list(**kwargs)
models_list = await client.aio.models.list(**kwargs)
return self._convert_list_models_response(models_list)
8 changes: 5 additions & 3 deletions src/any_llm/providers/groq/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ async def _aresponses(

return response

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = groq.Groq(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}))
models_list = client.models.list(**kwargs)
client = groq.AsyncGroq(
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
)
models_list = await client.models.list(**kwargs)
return self._convert_list_models_response(models_list)
2 changes: 1 addition & 1 deletion src/any_llm/providers/huggingface/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def _acompletion(
usage=usage,
)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
Expand Down
7 changes: 2 additions & 5 deletions src/any_llm/providers/mistral/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,11 @@ async def _aembedding(

return self._convert_embedding_response(result)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
client = Mistral(
api_key=self.config.api_key,
server_url=self.config.api_base,
**(self.config.client_args if self.config.client_args else {}),
)
models_list = client.models.list(**kwargs)
models_list = await client.models.list_async(**kwargs)
return self._convert_list_models_response(models_list)
13 changes: 5 additions & 8 deletions src/any_llm/providers/ollama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

MISSING_PACKAGES_ERROR = None
try:
from ollama import AsyncClient, Client
from ollama import AsyncClient

from .utils import (
_convert_models_list,
Expand All @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Sequence

from ollama import AsyncClient, Client # noqa: TC004
from ollama import AsyncClient # noqa: TC004
from ollama import ChatResponse as OllamaChatResponse

from any_llm.config import ClientConfig
Expand Down Expand Up @@ -230,10 +230,7 @@ async def _aembedding(
)
return self._convert_embedding_response(response)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = Client(host=self.url, **(self.config.client_args if self.config.client_args else {}))
models_list = client.list(**kwargs)
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {}))
models_list = await client.list(**kwargs)
return self._convert_list_models_response(models_list)
6 changes: 3 additions & 3 deletions src/any_llm/providers/openai/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ async def _aembedding(
)
)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
if not self.SUPPORTS_LIST_MODELS:
message = f"{self.PROVIDER_NAME} does not support listing models."
raise NotImplementedError(message)
client = cast("OpenAI", self._get_client(sync=True))
response = client.models.list(**kwargs)
client = cast("AsyncOpenAI", self._get_client())
response = await client.models.list(**kwargs)
return self._convert_list_models_response(response)
2 changes: 1 addition & 1 deletion src/any_llm/providers/together/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TogetherProvider(AnyLLM):
SUPPORTS_COMPLETION_IMAGE = True
SUPPORTS_COMPLETION_PDF = True
SUPPORTS_EMBEDDING = False
SUPPORTS_LIST_MODELS = True
SUPPORTS_LIST_MODELS = False

MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR

Expand Down
2 changes: 1 addition & 1 deletion src/any_llm/providers/watsonx/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def _acompletion(

return self._convert_completion_response(response)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
Expand Down
9 changes: 5 additions & 4 deletions src/any_llm/providers/xai/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
MISSING_PACKAGES_ERROR = None
try:
from xai_sdk import AsyncClient as XaiAsyncClient
from xai_sdk import Client as XaiClient
from xai_sdk.chat import Chunk as XaiChunk
from xai_sdk.chat import Response as XaiResponse
from xai_sdk.chat import assistant, required_tool, system, tool_result, user
Expand Down Expand Up @@ -148,10 +147,12 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:

return self._convert_completion_response(response)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
async def _alist_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = XaiClient(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}))
models_list = client.models.list_language_models()
client = XaiAsyncClient(
api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})
)
models_list = await client.models.list_language_models()
return self._convert_list_models_response(models_list)
16 changes: 9 additions & 7 deletions tests/unit/providers/test_openai_base_provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

from any_llm.config import ClientConfig
from any_llm.providers.openai.base import BaseOpenAIProvider
from any_llm.types.model import Model


@patch("any_llm.providers.openai.base.OpenAI")
@patch("any_llm.providers.openai.base.AsyncOpenAI")
def test_list_models_returns_model_list_when_supported(mock_openai_class: MagicMock) -> None:
class ListModelsProvider(BaseOpenAIProvider):
SUPPORTS_LIST_MODELS = True
Expand All @@ -19,7 +19,7 @@ class ListModelsProvider(BaseOpenAIProvider):
Model(id="gpt-4", object="model", created=1687882411, owned_by="openai"),
]

mock_client = MagicMock()
mock_client = AsyncMock()
mock_client.models.list.return_value.data = mock_model_data
mock_openai_class.return_value = mock_client

Expand All @@ -33,7 +33,7 @@ class ListModelsProvider(BaseOpenAIProvider):
mock_client.models.list.assert_called_once_with()


@patch("any_llm.providers.openai.base.OpenAI")
@patch("any_llm.providers.openai.base.AsyncOpenAI")
def test_list_models_uses_default_api_base_when_not_configured(mock_openai_class: MagicMock) -> None:
class ListModelsProvider(BaseOpenAIProvider):
SUPPORTS_LIST_MODELS = True
Expand All @@ -42,7 +42,7 @@ class ListModelsProvider(BaseOpenAIProvider):
PROVIDER_DOCUMENTATION_URL = "https://example.com"
API_BASE = "https://api.default.com/v1"

mock_client = MagicMock()
mock_client = AsyncMock()
mock_client.models.list.return_value.data = []
mock_openai_class.return_value = mock_client

Expand All @@ -54,15 +54,17 @@ class ListModelsProvider(BaseOpenAIProvider):
mock_openai_class.assert_called_once_with(base_url="https://api.default.com/v1", api_key="test-key")


@patch("any_llm.providers.openai.base.OpenAI")
@patch(
"any_llm.providers.openai.base.AsyncOpenAI",
)
def test_list_models_passes_kwargs_to_client(mock_openai_class: MagicMock) -> None:
class ListModelsProvider(BaseOpenAIProvider):
SUPPORTS_LIST_MODELS = True
PROVIDER_NAME = "ListModelsProvider"
ENV_API_KEY_NAME = "TEST_API_KEY"
PROVIDER_DOCUMENTATION_URL = "https://example.com"

mock_client = MagicMock()
mock_client = AsyncMock()
mock_client.models.list.return_value.data = []
mock_openai_class.return_value = mock_client

Expand Down