Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions marimo/_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,36 +726,4 @@ def merge_config(
):
merged["runtime"]["auto_reload"] = "lazy"

# If missing ai.models.chat_model or ai.models.edit_model, use ai.open_ai.model
openai_model = merged.get("ai", {}).get("open_ai", {}).get("model")
chat_model = merged.get("ai", {}).get("models", {}).get("chat_model")
edit_model = merged.get("ai", {}).get("models", {}).get("edit_model")
if not chat_model and not edit_model and openai_model:
merged_ai_config = cast(dict[Any, Any], merged.get("ai", {}))
models_config = {
"models": {
"chat_model": chat_model or openai_model,
"edit_model": edit_model or openai_model,
}
}
merged["ai"] = cast(
AiConfig, deep_merge(merged_ai_config, models_config)
)

# Migrate completion.model to ai.models.autocomplete_model
completion_model = merged.get("completion", {}).get("model")
autocomplete_model = (
merged.get("ai", {}).get("models", {}).get("autocomplete_model")
)
if completion_model and not autocomplete_model:
merged_ai_config = cast(dict[Any, Any], merged.get("ai", {}))
models_config = {
"models": {
"autocomplete_model": completion_model,
}
}
merged["ai"] = cast(
AiConfig, deep_merge(merged_ai_config, models_config)
)

return merged
38 changes: 29 additions & 9 deletions marimo/_server/ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,22 @@ def __post_init__(self) -> None:
def for_openai(cls, config: AiConfig) -> AnyProviderConfig:
fallback_key = cls.os_key("OPENAI_API_KEY")
return cls._for_openai_like(
config, "open_ai", "OpenAI", fallback_key=fallback_key
config,
"open_ai",
"OpenAI",
fallback_key=fallback_key,
require_key=True,
)

@classmethod
def for_azure(cls, config: AiConfig) -> AnyProviderConfig:
fallback_key = cls.os_key("AZURE_API_KEY")
return cls._for_openai_like(
config, "azure", "Azure OpenAI", fallback_key=fallback_key
config,
"azure",
"Azure OpenAI",
fallback_key=fallback_key,
require_key=True,
)

@classmethod
Expand Down Expand Up @@ -84,6 +92,7 @@ def for_github(cls, config: AiConfig) -> AnyProviderConfig:
fallback_key=fallback_key,
# Default base URL for GitHub Copilot
fallback_base_url="https://api.githubcopilot.com/",
require_key=True,
)

@classmethod
Expand All @@ -95,9 +104,12 @@ def _for_openai_like(
*,
fallback_key: Optional[str] = None,
fallback_base_url: Optional[str] = None,
require_key: bool = False,
) -> AnyProviderConfig:
ai_config = _get_ai_config(config, key)
key = _get_key(ai_config, name, fallback_key=fallback_key)
ai_config: dict[str, Any] = _get_ai_config(config, key)
key = _get_key(
ai_config, name, fallback_key=fallback_key, require_key=require_key
)

kwargs: dict[str, Any] = {
"base_url": _get_base_url(ai_config) or fallback_base_url,
Expand All @@ -119,6 +131,7 @@ def for_anthropic(cls, config: AiConfig) -> AnyProviderConfig:
ai_config,
"Anthropic",
fallback_key=fallback_key,
require_key=True,
)
return cls(
base_url=_get_base_url(ai_config),
Expand All @@ -136,6 +149,7 @@ def for_google(cls, config: AiConfig) -> AnyProviderConfig:
ai_config,
"Google AI",
fallback_key=fallback_key,
require_key=True,
)
return cls(
base_url=_get_base_url(ai_config),
Expand Down Expand Up @@ -175,7 +189,9 @@ def for_model(cls, model: str, config: AiConfig) -> AnyProviderConfig:
else:
# Catch-all: try OpenAI compatible first, then OpenAI.
try:
return cls.for_openai_compatible(config)
if "open_ai_compatible" in config:
return cls.for_openai_compatible(config)
return cls.for_openai(config)
except HTTPException:
return cls.for_openai(config)

Expand Down Expand Up @@ -243,6 +259,7 @@ def _get_key(
name: str,
*,
fallback_key: Optional[str] = None,
require_key: bool = False,
) -> str:
"""Get the API key for a given provider."""
if not isinstance(config, dict):
Expand Down Expand Up @@ -277,10 +294,13 @@ def _get_key(
if fallback_key:
return fallback_key

raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"{name} API key not configured. Go to Settings > AI to configure.",
)
if require_key:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"{name} API key not configured. Go to Settings > AI to configure.",
)

return ""


def _get_base_url(config: Any, name: str = "") -> Optional[str]:
Expand Down
254 changes: 0 additions & 254 deletions tests/_config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,257 +140,3 @@ def test_merge_config_with_keymap_overrides() -> None:

assert new_config["keymap"]["preset"] == "vim"
assert new_config.get("keymap", {}).get("overrides", {}) == {}


def test_merge_config_openai_model_fallback() -> None:
"""Test that openai_model is used as fallback for missing chat_model and edit_model."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
"model": "gpt-4",
},
}
),
)

# When neither chat_model nor edit_model are present, openai_model should be used for both
assert (
new_config.get("ai", {}).get("models", {}).get("chat_model") == "gpt-4"
)
assert (
new_config.get("ai", {}).get("models", {}).get("edit_model") == "gpt-4"
)
assert new_config.get("ai", {}).get("open_ai", {}).get("model") == "gpt-4"


def test_merge_config_existing_chat_edit_models_no_fallback() -> None:
"""Test that existing chat_model and edit_model are preserved (no fallback)."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
"model": "gpt-4",
},
"models": {
"chat_model": "claude-3-sonnet",
"edit_model": "gpt-3.5-turbo",
"displayed_models": [],
"custom_models": [],
},
}
),
)

# Existing models should be preserved, not overridden by openai_model
assert (
new_config.get("ai", {}).get("models", {}).get("chat_model")
== "claude-3-sonnet"
)
assert (
new_config.get("ai", {}).get("models", {}).get("edit_model")
== "gpt-3.5-turbo"
)
assert new_config.get("ai", {}).get("open_ai", {}).get("model") == "gpt-4"


def test_merge_config_no_openai_model_no_fallback() -> None:
"""Test that no fallback happens when openai_model is not present."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
},
}
),
)

# No models should be set since openai_model is not present
assert "chat_model" not in new_config.get("ai", {}).get("models", {})
assert "edit_model" not in new_config.get("ai", {}).get("models", {})


def test_merge_config_partial_model_config() -> None:
"""Test fallback when only one of chat_model or edit_model is present."""
base_config = merge_default_config(PartialMarimoConfig())

# Test with only chat_model present
new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
"model": "gpt-4",
},
"models": {
"chat_model": "claude-3-sonnet",
"displayed_models": [],
"custom_models": [],
},
}
),
)

# chat_model should be preserved, but no fallback should happen since chat_model exists
assert (
new_config.get("ai", {}).get("models", {}).get("chat_model")
== "claude-3-sonnet"
)
assert "edit_model" not in new_config.get("ai", {}).get("models", {})

# Test with only edit_model present
new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
"model": "gpt-4",
},
"models": {
"edit_model": "gpt-3.5-turbo",
"displayed_models": [],
"custom_models": [],
},
}
),
)

# edit_model should be preserved, but no fallback should happen since edit_model exists
assert (
new_config.get("ai", {}).get("models", {}).get("edit_model")
== "gpt-3.5-turbo"
)
assert "chat_model" not in new_config.get("ai", {}).get("models", {})


def test_merge_config_empty_models_with_openai_model() -> None:
"""Test fallback when models config exists but chat_model and edit_model are empty/None."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
ai={
"open_ai": {
"api_key": "test_key",
"model": "gpt-4",
},
"models": {
"chat_model": "",
"edit_model": "",
"displayed_models": [],
"custom_models": [],
},
}
),
)

# Empty/None values should trigger fallback to openai_model
assert (
new_config.get("ai", {}).get("models", {}).get("chat_model") == "gpt-4"
)
assert (
new_config.get("ai", {}).get("models", {}).get("edit_model") == "gpt-4"
)


def test_merge_config_completion_model_migration() -> None:
"""Test that completion.model is migrated to ai.models.autocomplete_model."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
completion={
"activate_on_typing": True,
"copilot": "custom",
"model": "custom-model-v1",
}
),
)

# completion.model should be migrated to ai.models.autocomplete_model
assert (
new_config.get("ai", {}).get("models", {}).get("autocomplete_model")
== "custom-model-v1"
)
# Original completion.model should still exist for backward compatibility
assert new_config.get("completion", {}).get("model") == "custom-model-v1"


def test_merge_config_completion_model_no_migration_when_autocomplete_exists() -> (
None
):
"""Test that completion.model is not migrated when ai.models.autocomplete_model already exists."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
completion={
"activate_on_typing": True,
"copilot": "custom",
"model": "old-model",
},
ai={
"models": {
"autocomplete_model": "new-model",
"displayed_models": [],
"custom_models": [],
}
},
),
)

# ai.models.autocomplete_model should not be overridden
assert (
new_config.get("ai", {}).get("models", {}).get("autocomplete_model")
== "new-model"
)
# Original completion.model should still exist
assert new_config.get("completion", {}).get("model") == "old-model"


def test_merge_config_completion_model_migration_creates_ai_config() -> None:
"""Test that completion.model migration creates ai config when it doesn't exist."""
base_config = merge_default_config(PartialMarimoConfig())

new_config = merge_config(
base_config,
PartialMarimoConfig(
completion={
"activate_on_typing": True,
"copilot": "custom",
"model": "custom-model",
}
),
)

# Should create ai.models config with the migrated model
assert "ai" in new_config
assert "models" in new_config.get("ai", {})
assert (
new_config.get("ai", {}).get("models", {}).get("autocomplete_model")
== "custom-model"
)
assert (
new_config.get("ai", {}).get("models", {}).get("displayed_models")
== []
)
assert (
new_config.get("ai", {}).get("models", {}).get("custom_models") == []
)
Loading
Loading