Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/any_llm/providers/gemini/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ class GeminiProvider(GoogleProvider):
PROVIDER_DOCUMENTATION_URL = "https://ai.google.dev/gemini-api/docs"
ENV_API_KEY_NAME = "GEMINI_API_KEY/GOOGLE_API_KEY"

def _get_client(self, config: ClientConfig) -> "genai.Client":
"""Get Gemini API client."""
api_key = getattr(config, "api_key", None) or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
def _verify_and_set_api_key(self, config: ClientConfig) -> ClientConfig:
# Standardized API key handling. Splitting into its own function so that providers
# Can easily override this method if they don't want verification (for instance, LMStudio)
if not config.api_key:
config.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

if not api_key:
msg = "Google Gemini Developer API"
raise MissingApiKeyError(msg, "GEMINI_API_KEY/GOOGLE_API_KEY")
if not config.api_key:
raise MissingApiKeyError(self.PROVIDER_NAME, self.ENV_API_KEY_NAME)
return config

return genai.Client(api_key=api_key, **(config.client_args if config.client_args else {}))
def _get_client(self, config: ClientConfig) -> "genai.Client":
"""Get Gemini API client."""
return genai.Client(api_key=config.api_key, **(config.client_args if config.client_args else {}))
15 changes: 15 additions & 0 deletions tests/unit/providers/test_google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def mock_google_provider(): # type: ignore[no-untyped-def]
yield mock_genai


@pytest.mark.parametrize("env_var", ["GEMINI_API_KEY", "GOOGLE_API_KEY"])
def test_gemini_initialization_with_env_var_api_key(env_var: str) -> None:
"""Test that the provider initializes correctly with API key from environment variable."""
with patch.dict("os.environ", {env_var: "env-api-key"}, clear=True):
provider = GeminiProvider(ClientConfig())
assert provider.config.api_key == "env-api-key"


def test_vertexai_initialization_with_env_var_api_key() -> None:
"""Test that the VertexaiProvider initializes correctly with GOOGLE_PROJECT_ID from environment variable."""
with patch.dict("os.environ", {"GOOGLE_PROJECT_ID": "env-project-id"}, clear=True):
provider = VertexaiProvider(ClientConfig())
assert provider.config.api_key == "env-project-id"


@pytest.mark.asyncio
async def test_completion_with_system_instruction(google_provider_class: type[Provider]) -> None:
"""Test that completion works correctly with system_instruction."""
Expand Down
Loading