Skip to content

Commit c672fe2

Browse files
authored
feat: use internal API for more organized implementations (#412)
## Description <!-- What does this PR do? --> Better structuring for providers where things are organized into shared function definitions ## PR Type <!-- Delete the types that don't apply --!> 💅 Refactor ## Relevant issues <!-- e.g. "Fixes #123" --> ## Checklist - [x] I have added unit tests that prove my fix/feature works - [x] New and existing tests pass locally - [x] Documentation was updated where necessary - [x] I have read and followed the [contribution guidelines](https://github.com/mozilla-ai/any-llm/blob/main/CONTRIBUTING.md)```
1 parent 3657e5a commit c672fe2

File tree

19 files changed

+1062
-297
lines changed

19 files changed

+1062
-297
lines changed

scripts/check_missing_api_keys.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to check which providers are being skipped due to missing API keys.
4+
5+
This script attempts to instantiate each provider and reports which ones
6+
fail due to missing API keys, missing packages, or other issues.
7+
"""
8+
9+
import os
10+
import sys
11+
from pathlib import Path
12+
13+
from any_llm.exceptions import MissingApiKeyError
14+
from any_llm.provider import ProviderFactory, ProviderName
15+
16+
src_path = Path(__file__).parent.parent / "src"
17+
sys.path.insert(0, str(src_path))
18+
19+
20+
def check_provider_status():
21+
"""Check the status of all providers and categorize them."""
22+
available_providers = []
23+
missing_api_keys = []
24+
missing_packages = []
25+
other_errors = []
26+
27+
print("Checking provider status...")
28+
print("=" * 50)
29+
30+
for provider_name in ProviderName:
31+
try:
32+
provider_class = ProviderFactory.get_provider_class(provider_name)
33+
34+
if provider_class.MISSING_PACKAGES_ERROR is not None:
35+
missing_packages.append(
36+
{
37+
"name": provider_name.value,
38+
"error": str(provider_class.MISSING_PACKAGES_ERROR),
39+
"env_var": provider_class.ENV_API_KEY_NAME,
40+
}
41+
)
42+
continue
43+
44+
available_providers.append(
45+
{
46+
"name": provider_name.value,
47+
"env_var": provider_class.ENV_API_KEY_NAME,
48+
"api_key_set": bool(os.getenv(provider_class.ENV_API_KEY_NAME)),
49+
}
50+
)
51+
52+
except MissingApiKeyError as e:
53+
missing_api_keys.append({"name": provider_name.value, "env_var": e.env_var_name, "error": str(e)})
54+
except ImportError as e:
55+
missing_packages.append({"name": provider_name.value, "error": str(e), "env_var": "N/A"})
56+
except Exception as e:
57+
other_errors.append({"name": provider_name.value, "error": str(e), "error_type": type(e).__name__})
58+
59+
return available_providers, missing_api_keys, missing_packages, other_errors
60+
61+
62+
def print_results(available_providers, missing_api_keys, missing_packages, other_errors):
63+
"""Print formatted results of the provider status check."""
64+
65+
if available_providers:
66+
print(f"✅ Available Providers ({len(available_providers)}):")
67+
for provider in sorted(available_providers, key=lambda x: x["name"]):
68+
key_status = "🔑" if provider["api_key_set"] else "🔓"
69+
print(f" {key_status} {provider['name']} (env: {provider['env_var']})")
70+
print()
71+
72+
if missing_api_keys:
73+
print(f"🔑 Missing API Keys ({len(missing_api_keys)}):")
74+
for provider in sorted(missing_api_keys, key=lambda x: x["name"]):
75+
print(f" ❌ {provider['name']} - Set {provider['env_var']}")
76+
print()
77+
78+
if missing_packages:
79+
print(f"📦 Missing Packages ({len(missing_packages)}):")
80+
for provider in sorted(missing_packages, key=lambda x: x["name"]):
81+
print(f" ❌ {provider['name']} - {provider['error']}")
82+
print()
83+
84+
if other_errors:
85+
print(f"⚠️ Other Errors ({len(other_errors)}):")
86+
for provider in sorted(other_errors, key=lambda x: x["name"]):
87+
print(f" ❌ {provider['name']} ({provider['error_type']}) - {provider['error']}")
88+
print()
89+
90+
total_providers = len(list(ProviderName))
91+
available_count = len(available_providers)
92+
missing_keys_count = len(missing_api_keys)
93+
missing_packages_count = len(missing_packages)
94+
other_errors_count = len(other_errors)
95+
96+
print("📊 Summary:")
97+
print(f" Total providers: {total_providers}")
98+
print(f" Available: {available_count}")
99+
print(f" Missing API keys: {missing_keys_count}")
100+
print(f" Missing packages: {missing_packages_count}")
101+
print(f" Other errors: {other_errors_count}")
102+
103+
if missing_api_keys:
104+
print("\n💡 To fix missing API keys, set these environment variables:")
105+
for provider in sorted(missing_api_keys, key=lambda x: x["name"]):
106+
print(f" export {provider['env_var']}='your-api-key-here'")
107+
108+
109+
def main():
110+
"""Main function to run the provider status check."""
111+
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help"]:
112+
print(__doc__)
113+
print("\nUsage:")
114+
print(" python scripts/check_missing_api_keys.py")
115+
print(" python scripts/check_missing_api_keys.py --help")
116+
return
117+
118+
try:
119+
available, missing_keys, missing_packages, other_errors = check_provider_status()
120+
print_results(available, missing_keys, missing_packages, other_errors)
121+
except Exception as e:
122+
print(f"Error running provider check: {e}")
123+
sys.exit(1)
124+
125+
126+
if __name__ == "__main__":
127+
main()

src/any_llm/provider.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,42 @@ def _verify_and_set_api_key(self, config: ClientConfig) -> ClientConfig:
153153
raise MissingApiKeyError(self.PROVIDER_NAME, self.ENV_API_KEY_NAME)
154154
return config
155155

156+
@staticmethod
157+
@abstractmethod
158+
def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
159+
msg = "Subclasses must implement this method"
160+
raise NotImplementedError(msg)
161+
162+
@staticmethod
163+
@abstractmethod
164+
def _convert_completion_response(response: Any) -> ChatCompletion:
165+
msg = "Subclasses must implement this method"
166+
raise NotImplementedError(msg)
167+
168+
@staticmethod
169+
@abstractmethod
170+
def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk:
171+
msg = "Subclasses must implement this method"
172+
raise NotImplementedError(msg)
173+
174+
@staticmethod
175+
@abstractmethod
176+
def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]:
177+
msg = "Subclasses must implement this method"
178+
raise NotImplementedError(msg)
179+
180+
@staticmethod
181+
@abstractmethod
182+
def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse:
183+
msg = "Subclasses must implement this method"
184+
raise NotImplementedError(msg)
185+
186+
@staticmethod
187+
@abstractmethod
188+
def _convert_list_models_response(response: Any) -> Sequence[Model]:
189+
msg = "Subclasses must implement this method"
190+
raise NotImplementedError(msg)
191+
156192
@classmethod
157193
def get_provider_metadata(cls) -> ProviderMetadata:
158194
"""Get provider metadata without requiring instantiation.

src/any_llm/providers/anthropic/anthropic.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from collections.abc import AsyncIterator, Sequence
2-
from typing import Any
2+
from typing import TYPE_CHECKING, Any
33

44
from any_llm.provider import Provider
5-
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams
5+
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse
66
from any_llm.types.model import Model
77

88
MISSING_PACKAGES_ERROR = None
@@ -18,6 +18,11 @@
1818
except ImportError as e:
1919
MISSING_PACKAGES_ERROR = e
2020

21+
if TYPE_CHECKING:
22+
from anthropic.pagination import SyncPage
23+
from anthropic.types import Message
24+
from anthropic.types.model_info import ModelInfo as AnthropicModelInfo
25+
2126

2227
class AnthropicProvider(Provider):
2328
"""
@@ -40,6 +45,39 @@ class AnthropicProvider(Provider):
4045

4146
MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR
4247

48+
@staticmethod
49+
def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
50+
"""Convert CompletionParams to kwargs for Anthropic API."""
51+
return _convert_params(params, **kwargs)
52+
53+
@staticmethod
54+
def _convert_completion_response(response: "Message") -> ChatCompletion:
55+
"""Convert Anthropic Message to OpenAI ChatCompletion format."""
56+
return _convert_response(response)
57+
58+
@staticmethod
59+
def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk:
60+
"""Convert Anthropic streaming chunk to OpenAI ChatCompletionChunk format."""
61+
model_id = kwargs.get("model_id", "unknown")
62+
return _create_openai_chunk_from_anthropic_chunk(response, model_id)
63+
64+
@staticmethod
65+
def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]:
66+
"""Anthropic does not support embeddings."""
67+
msg = "Anthropic does not support embeddings"
68+
raise NotImplementedError(msg)
69+
70+
@staticmethod
71+
def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse:
72+
"""Anthropic does not support embeddings."""
73+
msg = "Anthropic does not support embeddings"
74+
raise NotImplementedError(msg)
75+
76+
@staticmethod
77+
def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]:
78+
"""Convert Anthropic models list to OpenAI format."""
79+
return _convert_models_list(response)
80+
4381
async def _stream_completion_async(
4482
self, client: "AsyncAnthropic", **kwargs: Any
4583
) -> AsyncIterator[ChatCompletionChunk]:
@@ -48,7 +86,7 @@ async def _stream_completion_async(
4886
**kwargs,
4987
) as anthropic_stream:
5088
async for event in anthropic_stream:
51-
yield _create_openai_chunk_from_anthropic_chunk(event, kwargs.get("model", "unknown"))
89+
yield self._convert_completion_chunk_response(event, model_id=kwargs.get("model", "unknown"))
5290

5391
async def acompletion(
5492
self,
@@ -63,14 +101,14 @@ async def acompletion(
63101
)
64102

65103
kwargs["provider_name"] = self.PROVIDER_NAME
66-
converted_kwargs = _convert_params(params, **kwargs)
104+
converted_kwargs = self._convert_completion_params(params, **kwargs)
67105

68106
if converted_kwargs.pop("stream", False):
69107
return self._stream_completion_async(client, **converted_kwargs)
70108

71109
message = await client.messages.create(**converted_kwargs)
72110

73-
return _convert_response(message)
111+
return self._convert_completion_response(message)
74112

75113
def list_models(self, **kwargs: Any) -> Sequence[Model]:
76114
"""List available models from Anthropic."""
@@ -80,4 +118,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]:
80118
**(self.config.client_args if self.config.client_args else {}),
81119
)
82120
models_list = client.models.list(**kwargs)
83-
return _convert_models_list(models_list)
121+
return self._convert_list_models_response(models_list)

0 commit comments

Comments
 (0)