Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
[project.optional-dependencies]

all = [
"any-llm-sdk[mistral,anthropic,huggingface,google,cohere,cerebras,fireworks,groq,aws,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,llamafile,llamacpp]"
"any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,cohere,cerebras,fireworks,groq,aws,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,llamafile,llamacpp]"
]

perplexity = []
Expand All @@ -32,7 +32,11 @@ anthropic = [
"anthropic",
]

google = [
gemini = [
"google-genai",
]

vertexai = [
"google-genai",
]

Expand Down
3 changes: 2 additions & 1 deletion src/any_llm/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ProviderName(StrEnum):
DATABRICKS = "databricks"
DEEPSEEK = "deepseek"
FIREWORKS = "fireworks"
GOOGLE = "google"
GEMINI = "gemini"
GROQ = "groq"
HUGGINGFACE = "huggingface"
INCEPTION = "inception"
Expand All @@ -60,6 +60,7 @@ class ProviderName(StrEnum):
PORTKEY = "portkey"
SAMBANOVA = "sambanova"
TOGETHER = "together"
VERTEXAI = "vertexai"
VOYAGE = "voyage"
WATSONX = "watsonx"
XAI = "xai"
Expand Down
3 changes: 3 additions & 0 deletions src/any_llm/providers/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .gemini import GeminiProvider

__all__ = ["GeminiProvider"]
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from abc import abstractmethod
from collections.abc import AsyncIterator, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any, Literal, cast

from pydantic import BaseModel

from any_llm.exceptions import MissingApiKeyError, UnsupportedParameterError
from any_llm.exceptions import UnsupportedParameterError
from any_llm.provider import ClientConfig, Provider
from any_llm.types.completion import (
ChatCompletion,
Expand Down Expand Up @@ -38,16 +38,14 @@
except ImportError as e:
MISSING_PACKAGES_ERROR = e

# From https://ai.google.dev/gemini-api/docs/openai#thinking
if TYPE_CHECKING:
from google import genai

REASONING_EFFORT_TO_THINKING_BUDGETS = {"minimal": 256, "low": 1024, "medium": 8192, "high": 24576}


class GoogleProvider(Provider):
"""Google Provider using the new response conversion utilities."""

PROVIDER_NAME = "google"
PROVIDER_DOCUMENTATION_URL = "https://cloud.google.com/vertex-ai/docs"
ENV_API_KEY_NAME = "GOOGLE_API_KEY/GEMINI_API_KEY"
"""Base Google Provider class with common functionality for Gemini and Vertex AI."""

SUPPORTS_COMPLETION_STREAMING = True
SUPPORTS_COMPLETION = True
Expand All @@ -58,43 +56,17 @@ class GoogleProvider(Provider):

MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR

def __init__(self, config: ClientConfig) -> None:
"""Initialize Google GenAI provider."""
self._verify_no_missing_packages()
self.config = config
self.use_vertex_ai = os.getenv("GOOGLE_USE_VERTEX_AI", "false").lower() == "true"

def _get_client(self, use_vertex_ai: bool, config: ClientConfig) -> "genai.Client":
if use_vertex_ai:
project_id = os.getenv("GOOGLE_PROJECT_ID")
location = os.getenv("GOOGLE_REGION", "us-central1")

if not project_id:
msg = "Google Vertex AI"
raise MissingApiKeyError(msg, "GOOGLE_PROJECT_ID")

return genai.Client(
vertexai=True,
project=project_id,
location=location,
**(config.client_args if config.client_args else {}),
)

api_key = getattr(config, "api_key", None) or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

if not api_key:
msg = "Google Gemini Developer API"
raise MissingApiKeyError(msg, "GEMINI_API_KEY/GOOGLE_API_KEY")

return genai.Client(api_key=api_key, **(config.client_args if config.client_args else {}))
@abstractmethod
def _get_client(self, config: ClientConfig) -> "genai.Client":
"""Get the appropriate client for this provider implementation."""

async def aembedding(
self,
model: str,
inputs: str | list[str],
**kwargs: Any,
) -> CreateEmbeddingResponse:
client = self._get_client(self.use_vertex_ai, self.config)
client = self._get_client(self.config)
result = await client.aio.models.embed_content(
model=model,
contents=inputs, # type: ignore[arg-type]
Expand Down Expand Up @@ -125,15 +97,13 @@ async def acompletion(

if params.reasoning_effort is None:
kwargs["thinking_config"] = types.ThinkingConfig(include_thoughts=False)
# in "auto" mode, we just don't pass a `thinking_config`
elif params.reasoning_effort != "auto":
kwargs["thinking_config"] = types.ThinkingConfig(
include_thoughts=True, thinking_budget=REASONING_EFFORT_TO_THINKING_BUDGETS[params.reasoning_effort]
)

stream = bool(params.stream)
response_format = params.response_format
# Build generation config without duplicating keys (e.g., tools)
base_kwargs = params.model_dump(
exclude_none=True,
exclude={
Expand All @@ -148,7 +118,6 @@ async def acompletion(
},
)

# Convert max_tokens to max_output_tokens for Google
if params.max_tokens is not None:
base_kwargs["max_output_tokens"] = params.max_tokens

Expand All @@ -162,7 +131,7 @@ async def acompletion(
if system_instruction:
generation_config.system_instruction = system_instruction

client = self._get_client(self.use_vertex_ai, self.config)
client = self._get_client(self.config)
if stream:
response_stream = await client.aio.models.generate_content_stream(
model=params.model_id,
Expand All @@ -184,7 +153,6 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:

response_dict = _convert_response_to_response_dict(response)

# Directly construct ChatCompletion
choices_out: list[Choice] = []
for i, choice_item in enumerate(response_dict.get("choices", [])):
message_dict: dict[str, Any] = choice_item.get("message", {})
Expand All @@ -211,8 +179,6 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:
tool_calls=tool_calls,
reasoning=Reasoning(content=reasoning_content) if reasoning_content else None,
)
from typing import Literal, cast

choices_out.append(
Choice(
index=i,
Expand Down Expand Up @@ -241,9 +207,7 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]:
)

def list_models(self, **kwargs: Any) -> Sequence[Model]:
"""
Fetch available models from the /v1/models endpoint.
"""
client = self._get_client(self.use_vertex_ai, self.config)
"""Fetch available models from the /v1/models endpoint."""
client = self._get_client(self.config)
models_list = client.models.list(**kwargs)
return _convert_models_list(models_list)
26 changes: 26 additions & 0 deletions src/any_llm/providers/gemini/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

from google import genai

from any_llm.exceptions import MissingApiKeyError
from any_llm.provider import ClientConfig

from .base import GoogleProvider


class GeminiProvider(GoogleProvider):
    """Google provider variant that authenticates against the GenAI Developer API with an API key."""

    PROVIDER_NAME = "gemini"
    PROVIDER_DOCUMENTATION_URL = "https://ai.google.dev/gemini-api/docs"
    ENV_API_KEY_NAME = "GEMINI_API_KEY/GOOGLE_API_KEY"

    def _get_client(self, config: ClientConfig) -> "genai.Client":
        """Create a ``genai.Client`` authenticated with an API key.

        The key is resolved in this order, taking the first non-empty value:
        ``config.api_key``, the ``GEMINI_API_KEY`` environment variable, then
        the ``GOOGLE_API_KEY`` environment variable.

        Args:
            config: Client configuration; ``client_args`` (when set) is forwarded
                to ``genai.Client`` as extra keyword arguments.

        Returns:
            A client constructed with the resolved API key.

        Raises:
            MissingApiKeyError: If none of the three sources yields a key.
        """
        resolved_key = getattr(config, "api_key", None)
        if not resolved_key:
            resolved_key = os.getenv("GEMINI_API_KEY")
        if not resolved_key:
            resolved_key = os.getenv("GOOGLE_API_KEY")

        if not resolved_key:
            message = "Google Gemini Developer API"
            raise MissingApiKeyError(message, "GEMINI_API_KEY/GOOGLE_API_KEY")

        extra_client_args = config.client_args or {}
        return genai.Client(api_key=resolved_key, **extra_client_args)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def _convert_tool_spec(openai_tools: list[dict[str, Any]]) -> list[types.Tool]:
continue

function = tool["function"]
# Preserve nested schema details such as items/additionalProperties for arrays/objects
properties: dict[str, dict[str, Any]] = {}
for param_name, param_info in function["parameters"]["properties"].items():
prop: dict[str, Any] = {
Expand All @@ -35,12 +34,10 @@ def _convert_tool_spec(openai_tools: list[dict[str, Any]]) -> list[types.Tool]:
}
if "enum" in param_info:
prop["enum"] = param_info["enum"]
# Google requires explicit items for arrays
if "items" in param_info:
prop["items"] = param_info["items"]
if prop.get("type") == "array" and "items" not in prop:
prop["items"] = {"type": "string"}
# Google tool schema does not accept additionalProperties; drop it
properties[param_name] = prop

parameters_dict = {
Expand Down Expand Up @@ -85,7 +82,7 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[types.Conten
formatted_messages.append(types.Content(role="user", parts=parts))
elif message["role"] == "assistant":
if message.get("tool_calls"):
tool_call = message["tool_calls"][0] # Assuming single function call for now
tool_call = message["tool_calls"][0]
function_call = tool_call["function"]

parts = [
Expand Down Expand Up @@ -202,7 +199,6 @@ def _create_openai_embedding_response_from_google(
if embedding.values
]

# Google does not provide usage data in the embedding response
usage = Usage(prompt_tokens=0, total_tokens=0)

return CreateEmbeddingResponse(
Expand All @@ -228,10 +224,8 @@ def _create_openai_chunk_from_google_chunk(

for part in candidate.content.parts:
if part.thought:
# This is a thinking/reasoning part
reasoning_content += part.text or ""
else:
# Regular content part
content += part.text or ""

delta = ChoiceDelta(
Expand All @@ -247,7 +241,7 @@ def _create_openai_chunk_from_google_chunk(
)

return ChatCompletionChunk(
id=f"chatcmpl-{time()}", # Google doesn't provide an ID in the chunk
id=f"chatcmpl-{time()}",
choices=[choice],
created=int(time()),
model=str(response.model_version),
Expand All @@ -256,5 +250,4 @@ def _create_openai_chunk_from_google_chunk(


def _convert_models_list(models_list: Pager[types.Model]) -> list[Model]:
    """Convert the Google model pager into a list of ``Model`` entries.

    Each entry uses the Google model's ``name`` as its ``id`` (``"Unknown"`` when
    the name is empty or missing), with ``owned_by`` fixed to ``"google"``.
    ``created`` is always 0 because Google doesn't provide a creation date for models.
    """
    converted: list[Model] = []
    for google_model in models_list:
        model_id = google_model.name or "Unknown"
        converted.append(Model(id=model_id, object="model", created=0, owned_by="google"))
    return converted
3 changes: 0 additions & 3 deletions src/any_llm/providers/google/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/any_llm/providers/vertexai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .vertexai import VertexaiProvider

__all__ = ["VertexaiProvider"]
35 changes: 35 additions & 0 deletions src/any_llm/providers/vertexai/vertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
from typing import TYPE_CHECKING

from any_llm.exceptions import MissingApiKeyError
from any_llm.provider import ClientConfig
from any_llm.providers.gemini.base import GoogleProvider

if TYPE_CHECKING:
from google import genai


class VertexaiProvider(GoogleProvider):
    """Google provider variant that talks to Google Cloud Vertex AI."""

    PROVIDER_NAME = "vertexai"
    PROVIDER_DOCUMENTATION_URL = "https://cloud.google.com/vertex-ai/docs"
    ENV_API_KEY_NAME = "GOOGLE_PROJECT_ID"

    def _get_client(self, config: ClientConfig) -> "genai.Client":
        """Create a ``genai.Client`` in Vertex AI mode.

        The project comes from the ``GOOGLE_PROJECT_ID`` environment variable and
        the location from ``GOOGLE_REGION`` (``"us-central1"`` when unset).
        ``config.api_key`` is not consulted here.

        Args:
            config: Client configuration; ``client_args`` (when set) is forwarded
                to ``genai.Client`` as extra keyword arguments.

        Returns:
            A client constructed with ``vertexai=True`` for the resolved project
            and location.

        Raises:
            MissingApiKeyError: If ``GOOGLE_PROJECT_ID`` is unset or empty.
        """
        # Runtime import: at module level `genai` is only imported under TYPE_CHECKING.
        from google import genai

        gcp_project = os.getenv("GOOGLE_PROJECT_ID")
        gcp_location = os.getenv("GOOGLE_REGION", "us-central1")

        if not gcp_project:
            message = "Google Vertex AI"
            raise MissingApiKeyError(message, "GOOGLE_PROJECT_ID")

        extra_client_args = config.client_args or {}
        return genai.Client(
            vertexai=True,
            project=gcp_project,
            location=gcp_location,
            **extra_client_args,
        )
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ def provider_reasoning_model_map() -> dict[ProviderName, str]:
return {
ProviderName.ANTHROPIC: "claude-sonnet-4-20250514",
ProviderName.MISTRAL: "magistral-small-latest",
ProviderName.GOOGLE: "gemini-2.5-flash",
ProviderName.GEMINI: "gemini-2.5-flash",
ProviderName.VERTEXAI: "gemini-2.5-flash",
ProviderName.GROQ: "openai/gpt-oss-20b",
ProviderName.FIREWORKS: "accounts/fireworks/models/deepseek-r1",
ProviderName.OPENAI: "gpt-5-nano",
Expand All @@ -33,7 +34,8 @@ def provider_model_map() -> dict[ProviderName, str]:
ProviderName.DATABRICKS: "databricks-meta-llama-3-1-8b-instruct",
ProviderName.DEEPSEEK: "deepseek-chat",
ProviderName.OPENAI: "gpt-5-nano",
ProviderName.GOOGLE: "gemini-2.5-flash",
ProviderName.GEMINI: "gemini-2.5-flash",
ProviderName.VERTEXAI: "gemini-2.5-flash",
ProviderName.MOONSHOT: "moonshot-v1-8k",
ProviderName.SAMBANOVA: "Meta-Llama-3.1-8B-Instruct",
ProviderName.TOGETHER: "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
Expand Down Expand Up @@ -73,7 +75,8 @@ def embedding_provider_model_map() -> dict[ProviderName, str]:
ProviderName.OLLAMA: "gpt-oss:20b",
ProviderName.LLAMAFILE: "N/A",
ProviderName.LMSTUDIO: "text-embedding-nomic-embed-text-v1.5",
ProviderName.GOOGLE: "gemini-embedding-001",
ProviderName.GEMINI: "gemini-embedding-001",
ProviderName.VERTEXAI: "gemini-embedding-001",
ProviderName.AZURE: "openai/text-embedding-3-small",
ProviderName.AZUREOPENAI: "azure/<your_deployment_name>",
ProviderName.VOYAGE: "voyage-3.5-lite",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ async def test_embedding_providers_async(
assert len(result.data) > 0
for entry in result.data:
assert all(isinstance(v, float) for v in entry.embedding)
if provider not in (ProviderName.GOOGLE, ProviderName.LMSTUDIO):
if provider not in (ProviderName.GEMINI, ProviderName.VERTEXAI, ProviderName.LMSTUDIO):
assert result.usage.prompt_tokens > 0
assert result.usage.total_tokens > 0
4 changes: 2 additions & 2 deletions tests/integration/test_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_completion_reasoning(

model_id = provider_reasoning_model_map[provider]
extra_kwargs = provider_extra_kwargs_map.get(provider, {})
if provider in (ProviderName.ANTHROPIC, ProviderName.GOOGLE, ProviderName.OLLAMA):
if provider in (ProviderName.ANTHROPIC, ProviderName.GEMINI, ProviderName.VERTEXAI, ProviderName.OLLAMA):
extra_kwargs["reasoning_effort"] = "low"

try:
Expand Down Expand Up @@ -62,7 +62,7 @@ async def test_completion_reasoning_streaming(

model_id = provider_reasoning_model_map[provider]
extra_kwargs = provider_extra_kwargs_map.get(provider, {})
if provider in (ProviderName.ANTHROPIC, ProviderName.GOOGLE, ProviderName.OLLAMA):
if provider in (ProviderName.ANTHROPIC, ProviderName.GEMINI, ProviderName.VERTEXAI, ProviderName.OLLAMA):
extra_kwargs["reasoning_effort"] = "low"

try:
Expand Down
Loading
Loading