Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,15 @@ export function ProviderConfigModal({

// Validate connection before saving
// For local providers, we might skip this or just check if models exist (which the backend does)
const result = await api.testProviderConnection(provider.id, {
api_key: values.api_key,
base_url: values.base_url,
chat_model: values.chat_model,
});
if (!provider.is_custom) {
const result = await api.testProviderConnection(provider.id, {
api_key: values.api_key,
base_url: values.base_url,
chat_model: values.chat_model,
});

if (!result.success) {
message.error(result.message || t("models.testConnectionFailed"));
if (!provider.is_custom) {
if (!result.success) {
message.error(result.message || t("models.testConnectionFailed"));
// For built-in providers, we want to enforce valid config before saving
return;
}
Expand Down Expand Up @@ -502,14 +502,16 @@ export function ProviderConfigModal({
{t("models.revokeAuthorization")}
</Button>
)}
<Button
size="small"
icon={<ApiOutlined />}
onClick={handleTest}
loading={testing}
>
{t("models.testConnection")}
</Button>
{!provider.is_custom && (
<Button
size="small"
icon={<ApiOutlined />}
onClick={handleTest}
loading={testing}
>
{t("models.testConnection")}
</Button>
)}
Comment thread
pan-x-c marked this conversation as resolved.
</div>
<div className={styles.modalFooterRight}>
<Button onClick={onClose}>{t("models.cancel")}</Button>
Expand Down
9 changes: 7 additions & 2 deletions src/copaw/cli/providers_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def add_provider_cmd(
"""Add a new custom provider."""
manager = _manager()
try:
asyncio.run(
provider_info = asyncio.run(
manager.add_custom_provider(
ProviderInfo(
id=provider_id,
Expand All @@ -518,7 +518,12 @@ def add_provider_cmd(
except ValueError as exc:
click.echo(click.style(f"Error: {exc}", fg="red"))
raise SystemExit(1) from exc
click.echo(f"✓ Custom provider '{name}' ({provider_id}) created.")
click.echo(
"✓ Custom provider "
f"'{provider_info.name}' ({provider_info.id}) created.",
)
if provider_info.id != provider_id:
click.echo(f" requested id: {provider_id}")
if base_url:
click.echo(f" base_url: {base_url}")
click.echo(
Expand Down
31 changes: 21 additions & 10 deletions src/copaw/providers/provider_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,31 @@ async def fetch_provider_models(
)
return []

def _resolve_custom_provider_id(self, provider_id: str) -> str:
"""Resolve provider ID conflicts for a custom provider."""
base_id = provider_id
if base_id in self.builtin_providers:
base_id = f"{base_id}-custom"

resolved_id = base_id
while (
resolved_id in self.builtin_providers
or resolved_id in self.custom_providers
):
resolved_id = f"{resolved_id}-new"

return resolved_id

async def add_custom_provider(self, provider_data: ProviderInfo):
# Add a new custom provider with the given data. This will update the
# providers.json file and make the new provider available in the UI.
if provider_data.id in self.builtin_providers:
raise ValueError(
f"'{provider_data.id}' conflicts with a built-in provider.",
)
if provider_data.id in self.custom_providers:
raise ValueError(
f"Custom provider '{provider_data.id}' already exists.",
)
provider_data.is_custom = True
provider_payload = provider_data.model_dump()
provider_payload["id"] = self._resolve_custom_provider_id(
provider_data.id,
)
provider_payload["is_custom"] = True
Comment thread
pan-x-c marked this conversation as resolved.
provider = self._provider_from_data(
provider_data.model_dump(),
provider_payload,
) # Validate provider data
self.custom_providers[provider.id] = provider
self._save_provider(provider, is_builtin=False)
Expand Down
46 changes: 29 additions & 17 deletions tests/unit/providers/test_provider_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,33 @@ async def test_add_custom_provider_and_reload_from_storage(
models=[ModelInfo(id="custom-model", name="Custom Model")],
)

await manager.add_custom_provider(custom)
with pytest.raises(ValueError, match="conflicts with a built-in provider"):
await manager.add_custom_provider(
OpenAIProvider(
id="openai",
name="Conflict OpenAI",
),
)

with pytest.raises(
ValueError,
match="Custom provider 'custom-openai' already exists",
):
await manager.add_custom_provider(custom)
created = await manager.add_custom_provider(custom)
builtin_conflict = await manager.add_custom_provider(
OpenAIProvider(
id="openai",
name="Conflict OpenAI",
),
)
duplicate = await manager.add_custom_provider(custom)

reloaded = ProviderManager()
loaded = reloaded.get_provider("custom-openai")
loaded_builtin_conflict = reloaded.get_provider("openai-custom")
loaded_duplicate = reloaded.get_provider("custom-openai-new")

assert created.id == "custom-openai"
assert builtin_conflict.id == "openai-custom"
assert duplicate.id == "custom-openai-new"
assert loaded is not None
assert isinstance(loaded, OpenAIProvider)
assert loaded.is_custom is True
assert loaded.base_url == "https://custom.example/v1"
assert loaded.api_key == "sk-custom"
assert [m.id for m in loaded.models] == ["custom-model"]
assert loaded_builtin_conflict is not None
assert isinstance(loaded_builtin_conflict, OpenAIProvider)
assert loaded_duplicate is not None
assert isinstance(loaded_duplicate, OpenAIProvider)


async def test_activate_provider_persists_active_model(
Expand Down Expand Up @@ -222,7 +225,7 @@ def test_migrate_legacy_file_and_persist_active_model(
assert active_model_file.exists()


async def test_add_custom_provider_conflict_with_builtin_raises(
async def test_add_custom_provider_conflict_resolution_loops_until_unique(
isolated_secret_dir,
) -> None:
manager = ProviderManager()
Expand All @@ -231,8 +234,17 @@ async def test_add_custom_provider_conflict_with_builtin_raises(
name="Conflict OpenAI",
)

with pytest.raises(ValueError, match="conflicts with a built-in provider"):
await manager.add_custom_provider(conflict)
first = await manager.add_custom_provider(conflict)
second = await manager.add_custom_provider(conflict)
third = await manager.add_custom_provider(conflict)

assert first.id == "openai-custom"
assert second.id == "openai-custom-new"
assert third.id == "openai-custom-new-new"

assert manager.get_provider("openai-custom") is not None
assert manager.get_provider("openai-custom-new") is not None
assert manager.get_provider("openai-custom-new-new") is not None


def test_update_provider_for_builtin_persists_to_builtin_path(
Expand Down