Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion marimo/_server/ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,16 @@ def _for_openai_like(
ai_config, name, fallback_key=fallback_key, require_key=require_key
)

# Use SSL_CERT_FILE environment variable as fallback for ca_bundle_path
ca_bundle_path = ai_config.get("ca_bundle_path") or cls.os_key(
"SSL_CERT_FILE"
)

kwargs: dict[str, Any] = {
"base_url": _get_base_url(ai_config) or fallback_base_url,
"api_key": key,
"ssl_verify": ai_config.get("ssl_verify", True),
"ca_bundle_path": ai_config.get("ca_bundle_path", None),
"ca_bundle_path": ca_bundle_path,
"client_pem": ai_config.get("client_pem", None),
"extra_headers": ai_config.get("extra_headers", None),
"tools": _get_tools(config.get("mode", "manual")),
Expand Down
71 changes: 71 additions & 0 deletions tests/_server/ai/test_ai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,77 @@ def test_for_model_with_autocomplete_model(self) -> None:
assert provider_config.tools is None


class TestSSLConfiguration:
"""Tests for SSL configuration across all OpenAI-like providers."""

@pytest.mark.parametrize(
("provider_name", "provider_method", "api_key_config"),
[
("openai", "for_openai", {"open_ai": {"api_key": "test-key"}}),
("github", "for_github", {"github": {"api_key": "test-key"}}),
("ollama", "for_ollama", {"ollama": {"api_key": "test-key"}}),
],
)
def test_ssl_config_from_provider_config(
self,
provider_name: str,
provider_method: str,
api_key_config: AiConfig,
) -> None:
"""Test SSL configuration is read from provider config."""
# Get the provider key from api_key_config
provider_key = next(iter(api_key_config.keys()))

config: AiConfig = {
**api_key_config,
}
config[provider_key]["ssl_verify"] = False
config[provider_key]["ca_bundle_path"] = "/custom/path/to/ca.pem"
config[provider_key]["client_pem"] = "/custom/path/to/client.pem"
config[provider_key]["extra_headers"] = {"X-Custom": "header"}

method = getattr(AnyProviderConfig, provider_method)
provider_config = method(config)

assert provider_config.ssl_verify is False, (
f"{provider_name}: ssl_verify should be False"
)
assert provider_config.ca_bundle_path == "/custom/path/to/ca.pem", (
f"{provider_name}: ca_bundle_path should match"
)
assert provider_config.client_pem == "/custom/path/to/client.pem", (
f"{provider_name}: client_pem should match"
)
assert provider_config.extra_headers == {"X-Custom": "header"}, (
f"{provider_name}: extra_headers should match"
)

@pytest.mark.parametrize(
("provider_name", "provider_method", "api_key_config"),
[
("openai", "for_openai", {"open_ai": {"api_key": "test-key"}}),
("github", "for_github", {"github": {"api_key": "test-key"}}),
("ollama", "for_ollama", {"ollama": {"api_key": "test-key"}}),
],
)
@patch.dict(os.environ, {"SSL_CERT_FILE": "/env/path/to/ca.pem"})
def test_ssl_cert_file_fallback(
self,
provider_name: str,
provider_method: str,
api_key_config: AiConfig,
) -> None:
"""Test SSL_CERT_FILE environment variable is used as fallback."""
config: AiConfig = {**api_key_config}

method = getattr(AnyProviderConfig, provider_method)
provider_config = method(config)

assert provider_config.ca_bundle_path == "/env/path/to/ca.pem", (
f"{provider_name}: should use SSL_CERT_FILE env var as fallback"
)


class TestEdgeCases:
"""Tests for edge cases and error conditions."""

Expand Down
Loading