diff --git a/marimo/_server/ai/config.py b/marimo/_server/ai/config.py index 7a42876ba79..e7a95579d1e 100644 --- a/marimo/_server/ai/config.py +++ b/marimo/_server/ai/config.py @@ -124,11 +124,16 @@ def _for_openai_like( ai_config, name, fallback_key=fallback_key, require_key=require_key ) + # Use SSL_CERT_FILE environment variable as fallback for ca_bundle_path + ca_bundle_path = ai_config.get("ca_bundle_path") or cls.os_key( + "SSL_CERT_FILE" + ) + kwargs: dict[str, Any] = { "base_url": _get_base_url(ai_config) or fallback_base_url, "api_key": key, "ssl_verify": ai_config.get("ssl_verify", True), - "ca_bundle_path": ai_config.get("ca_bundle_path", None), + "ca_bundle_path": ca_bundle_path, "client_pem": ai_config.get("client_pem", None), "extra_headers": ai_config.get("extra_headers", None), "tools": _get_tools(config.get("mode", "manual")), diff --git a/tests/_server/ai/test_ai_config.py b/tests/_server/ai/test_ai_config.py index 7846b9106bf..1fc78fdd27f 100644 --- a/tests/_server/ai/test_ai_config.py +++ b/tests/_server/ai/test_ai_config.py @@ -974,6 +974,77 @@ def test_for_model_with_autocomplete_model(self) -> None: assert provider_config.tools is None +class TestSSLConfiguration: + """Tests for SSL configuration across all OpenAI-like providers.""" + + @pytest.mark.parametrize( + ("provider_name", "provider_method", "api_key_config"), + [ + ("openai", "for_openai", {"open_ai": {"api_key": "test-key"}}), + ("github", "for_github", {"github": {"api_key": "test-key"}}), + ("ollama", "for_ollama", {"ollama": {"api_key": "test-key"}}), + ], + ) + def test_ssl_config_from_provider_config( + self, + provider_name: str, + provider_method: str, + api_key_config: AiConfig, + ) -> None: + """Test SSL configuration is read from provider config.""" + # Get the provider key from api_key_config + provider_key = next(iter(api_key_config.keys())) + + config: AiConfig = { + **api_key_config, + } + config[provider_key]["ssl_verify"] = False + config[provider_key]["ca_bundle_path"] = "/custom/path/to/ca.pem" + config[provider_key]["client_pem"] = "/custom/path/to/client.pem" + config[provider_key]["extra_headers"] = {"X-Custom": "header"} + + method = getattr(AnyProviderConfig, provider_method) + provider_config = method(config) + + assert provider_config.ssl_verify is False, ( + f"{provider_name}: ssl_verify should be False" + ) + assert provider_config.ca_bundle_path == "/custom/path/to/ca.pem", ( + f"{provider_name}: ca_bundle_path should match" + ) + assert provider_config.client_pem == "/custom/path/to/client.pem", ( + f"{provider_name}: client_pem should match" + ) + assert provider_config.extra_headers == {"X-Custom": "header"}, ( + f"{provider_name}: extra_headers should match" + ) + + @pytest.mark.parametrize( + ("provider_name", "provider_method", "api_key_config"), + [ + ("openai", "for_openai", {"open_ai": {"api_key": "test-key"}}), + ("github", "for_github", {"github": {"api_key": "test-key"}}), + ("ollama", "for_ollama", {"ollama": {"api_key": "test-key"}}), + ], + ) + @patch.dict(os.environ, {"SSL_CERT_FILE": "/env/path/to/ca.pem"}) + def test_ssl_cert_file_fallback( + self, + provider_name: str, + provider_method: str, + api_key_config: AiConfig, + ) -> None: + """Test SSL_CERT_FILE environment variable is used as fallback.""" + config: AiConfig = {**api_key_config} + + method = getattr(AnyProviderConfig, provider_method) + provider_config = method(config) + + assert provider_config.ca_bundle_path == "/env/path/to/ca.pem", ( + f"{provider_name}: should use SSL_CERT_FILE env var as fallback" + ) + + class TestEdgeCases: """Tests for edge cases and error conditions."""