-
Notifications
You must be signed in to change notification settings - Fork 119
feat: use internal API for more organized implementations #412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
2beb3f2
edc7498
5e4eec8
b601f09
5bf6c18
b02be97
725963b
815d9d7
18b1bd3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
#!/usr/bin/env python3
"""
Script to check which providers are being skipped due to missing API keys.

This script attempts to instantiate each provider and reports which ones
fail due to missing API keys, missing packages, or other issues.
"""

import os
import sys
from pathlib import Path

# Make the in-repo package importable BEFORE importing from it.
# (The original inserted this path *after* the `any_llm` imports, which
# defeated the purpose of the path manipulation: by then `any_llm` had
# already been resolved from the environment, not from ./src.)
src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))

from any_llm.exceptions import MissingApiKeyError  # noqa: E402
from any_llm.provider import ProviderFactory, ProviderName  # noqa: E402
|
|
||
|
|
||
def check_provider_status():
    """Check the status of all providers and categorize them.

    Returns:
        A 4-tuple of lists ``(available, missing_api_keys, missing_packages,
        other_errors)``; each list holds one dict per provider describing
        its status.
    """
    available = []
    no_api_key = []
    no_package = []
    failed = []

    print("Checking provider status...")
    print("=" * 50)

    for provider in ProviderName:
        name = provider.value
        try:
            cls = ProviderFactory.get_provider_class(provider)
            import_error = cls.MISSING_PACKAGES_ERROR
            if import_error is not None:
                # Provider class loaded, but its optional SDK is absent.
                no_package.append(
                    {
                        "name": name,
                        "error": str(import_error),
                        "env_var": cls.ENV_API_KEY_NAME,
                    }
                )
            else:
                # Usable provider; record whether its key env var is set.
                available.append(
                    {
                        "name": name,
                        "env_var": cls.ENV_API_KEY_NAME,
                        "api_key_set": bool(os.getenv(cls.ENV_API_KEY_NAME)),
                    }
                )
        except MissingApiKeyError as exc:
            no_api_key.append({"name": name, "env_var": exc.env_var_name, "error": str(exc)})
        except ImportError as exc:
            no_package.append({"name": name, "error": str(exc), "env_var": "N/A"})
        except Exception as exc:
            failed.append({"name": name, "error": str(exc), "error_type": type(exc).__name__})

    return available, no_api_key, no_package, failed
|
|
||
|
|
||
def print_results(available_providers, missing_api_keys, missing_packages, other_errors):
    """Print formatted results of the provider status check."""

    def by_name(entry):
        # Sort key shared by every section: the provider's display name.
        return entry["name"]

    if available_providers:
        print(f"✅ Available Providers ({len(available_providers)}):")
        for entry in sorted(available_providers, key=by_name):
            icon = "🔑" if entry["api_key_set"] else "🔓"
            print(f" {icon} {entry['name']} (env: {entry['env_var']})")
        print()

    if missing_api_keys:
        print(f"🔑 Missing API Keys ({len(missing_api_keys)}):")
        for entry in sorted(missing_api_keys, key=by_name):
            print(f" ❌ {entry['name']} - Set {entry['env_var']}")
        print()

    if missing_packages:
        print(f"📦 Missing Packages ({len(missing_packages)}):")
        for entry in sorted(missing_packages, key=by_name):
            print(f" ❌ {entry['name']} - {entry['error']}")
        print()

    if other_errors:
        print(f"⚠️ Other Errors ({len(other_errors)}):")
        for entry in sorted(other_errors, key=by_name):
            print(f" ❌ {entry['name']} ({entry['error_type']}) - {entry['error']}")
        print()

    # Summary counts; the total comes from the enum, not the buckets, so it
    # stays correct even if a provider somehow lands in no bucket.
    print("📊 Summary:")
    print(f" Total providers: {len(list(ProviderName))}")
    print(f" Available: {len(available_providers)}")
    print(f" Missing API keys: {len(missing_api_keys)}")
    print(f" Missing packages: {len(missing_packages)}")
    print(f" Other errors: {len(other_errors)}")

    if missing_api_keys:
        print("\n💡 To fix missing API keys, set these environment variables:")
        for entry in sorted(missing_api_keys, key=by_name):
            print(f" export {entry['env_var']}='your-api-key-here'")
|
|
||
|
|
||
def main():
    """Main function to run the provider status check."""
    args = sys.argv[1:]
    if args and args[0] in ("-h", "--help"):
        # Show the module docstring plus usage and exit without checking.
        print(__doc__)
        print("\nUsage:")
        print(" python scripts/check_missing_api_keys.py")
        print(" python scripts/check_missing_api_keys.py --help")
        return

    try:
        results = check_provider_status()
        print_results(*results)
    except Exception as exc:
        print(f"Error running provider check: {exc}")
        sys.exit(1)
|
|
||
|
|
||
# Standard script entry point: run the check only when executed directly.
if __name__ == "__main__":
    main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,8 @@ | ||
| from collections.abc import AsyncIterator, Sequence | ||
| from typing import Any | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from any_llm.provider import Provider | ||
| from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams | ||
| from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse | ||
| from any_llm.types.model import Model | ||
|
|
||
| MISSING_PACKAGES_ERROR = None | ||
|
|
@@ -18,6 +18,11 @@ | |
| except ImportError as e: | ||
| MISSING_PACKAGES_ERROR = e | ||
|
|
||
| if TYPE_CHECKING: | ||
| from anthropic.pagination import SyncPage | ||
| from anthropic.types import Message | ||
| from anthropic.types.model_info import ModelInfo as AnthropicModelInfo | ||
|
|
||
|
|
||
| class AnthropicProvider(Provider): | ||
| """ | ||
|
|
@@ -39,6 +44,39 @@ class AnthropicProvider(Provider): | |
|
|
||
| MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR | ||
|
|
||
| @staticmethod | ||
| def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: | ||
| """Convert CompletionParams to kwargs for Anthropic API.""" | ||
| return _convert_params(params, **kwargs) | ||
|
|
||
| @staticmethod | ||
| def _convert_completion_response(response: "Message") -> ChatCompletion: | ||
| """Convert Anthropic Message to OpenAI ChatCompletion format.""" | ||
| return _convert_response(response) | ||
|
|
||
| @staticmethod | ||
| def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: | ||
| """Convert Anthropic streaming chunk to OpenAI ChatCompletionChunk format.""" | ||
| model_id = kwargs.get("model_id", "unknown") | ||
| return _create_openai_chunk_from_anthropic_chunk(response, model_id) | ||
|
|
||
| @staticmethod | ||
| def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: | ||
| """Anthropic does not support embeddings.""" | ||
| msg = "Anthropic does not support embeddings" | ||
| raise NotImplementedError(msg) | ||
|
|
||
| @staticmethod | ||
| def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: | ||
| """Anthropic does not support embeddings.""" | ||
| msg = "Anthropic does not support embeddings" | ||
| raise NotImplementedError(msg) | ||
|
|
||
| @staticmethod | ||
| def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]: | ||
| """Convert Anthropic models list to OpenAI format.""" | ||
| return _convert_models_list(response) | ||
|
|
||
| async def _stream_completion_async( | ||
| self, client: "AsyncAnthropic", **kwargs: Any | ||
| ) -> AsyncIterator[ChatCompletionChunk]: | ||
|
|
@@ -47,7 +85,7 @@ async def _stream_completion_async( | |
| **kwargs, | ||
| ) as anthropic_stream: | ||
| async for event in anthropic_stream: | ||
| yield _create_openai_chunk_from_anthropic_chunk(event, kwargs.get("model", "unknown")) | ||
| yield self._convert_completion_chunk_response(event, model_id=kwargs.get("model", "unknown")) | ||
|
|
||
| async def acompletion( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're standardizing the API by adding more required methods, should we abstract them into the base Provider class?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this idea. This PR is already a big boy, so I'll split this off into a separate PR, so that this change doesn't keep exploding into a bigger change.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| self, | ||
|
|
@@ -62,14 +100,14 @@ async def acompletion( | |
| ) | ||
|
|
||
| kwargs["provider_name"] = self.PROVIDER_NAME | ||
| converted_kwargs = _convert_params(params, **kwargs) | ||
| converted_kwargs = self._convert_completion_params(params, **kwargs) | ||
|
|
||
| if converted_kwargs.pop("stream", False): | ||
| return self._stream_completion_async(client, **converted_kwargs) | ||
|
|
||
| message = await client.messages.create(**converted_kwargs) | ||
|
|
||
| return _convert_response(message) | ||
| return self._convert_completion_response(message) | ||
|
|
||
| def list_models(self, **kwargs: Any) -> Sequence[Model]: | ||
| """List available models from Anthropic.""" | ||
|
|
@@ -79,4 +117,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: | |
| **(self.config.client_args if self.config.client_args else {}), | ||
| ) | ||
| models_list = client.models.list(**kwargs) | ||
| return _convert_models_list(models_list) | ||
| return self._convert_list_models_response(models_list) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about moving all the logic from the utils.py file into this function, but I ended up keeping it like this because it seemed smoother and kept the anthropic file from getting too busy 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will not complain about smaller files
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know that there are some provider-specific quirks, but it would be super great if we could make those happen in an inherited call with stuff before and after a
super().acompletion(*args, **kwargs)