Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions marimo/_server/ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,9 @@ class GoogleProvider(
"gemini-2.5-flash",
]

# Keep a persistent async client to avoid closing during stream iteration
_client: Optional[GoogleClient] = None

def is_thinking_model(self, model: str) -> bool:
return any(
model.startswith(prefix) for prefix in self.THINKING_MODEL_PREFIXES
Expand Down Expand Up @@ -832,6 +835,10 @@ def get_client(self, config: AnyProviderConfig) -> GoogleClient:
)
from google import genai # type: ignore

# Reuse a stored async client if already created
if self._client is not None:
return self._client

# If no API key is provided, try to use environment variables and ADC
# This supports Google Vertex AI usage without explicit API keys
if not config.api_key:
Expand All @@ -842,14 +849,18 @@ def get_client(self, config: AnyProviderConfig) -> GoogleClient:
if use_vertex:
project = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
return genai.Client(
self._client = genai.Client(
vertexai=True, project=project, location=location
).aio
else:
# Try default initialization which may work with environment variables
return genai.Client().aio
self._client = genai.Client().aio

# Return vertex or default client
return self._client

return genai.Client(api_key=config.api_key).aio
self._client = genai.Client(api_key=config.api_key).aio
return self._client

async def stream_completion(
self,
Expand All @@ -859,7 +870,7 @@ async def stream_completion(
additional_tools: list[ToolDefinition],
) -> AsyncIterator[GenerateContentResponse]:
client = self.get_client(self.config)
return await client.models.generate_content_stream(
return await client.models.generate_content_stream( # type: ignore[reportReturnType]
model=self.model,
contents=convert_to_google_messages(messages),
config=self.get_config(
Expand Down
Loading