diff --git a/config/config.example.json b/config/config.example.json index aa75c8338b..8e2bc315ee 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -91,6 +91,10 @@ "api_key": "YOUR_ZHIPU_API_KEY", "api_base": "" }, + "zai": { + "api_key": "YOUR_ZAI_API_KEY", + "api_base": "" + }, "gemini": { "api_key": "", "api_base": "" diff --git a/pkg/config/config.go b/pkg/config/config.go index d76ec80955..97e6975e89 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -172,6 +172,7 @@ type ProvidersConfig struct { OpenRouter ProviderConfig `json:"openrouter"` Groq ProviderConfig `json:"groq"` Zhipu ProviderConfig `json:"zhipu"` + ZAI ProviderConfig `json:"zai"` VLLM ProviderConfig `json:"vllm"` Gemini ProviderConfig `json:"gemini"` Nvidia ProviderConfig `json:"nvidia"` @@ -299,6 +300,7 @@ func DefaultConfig() *Config { OpenRouter: ProviderConfig{}, Groq: ProviderConfig{}, Zhipu: ProviderConfig{}, + ZAI: ProviderConfig{}, VLLM: ProviderConfig{}, Gemini: ProviderConfig{}, Nvidia: ProviderConfig{}, @@ -396,6 +398,9 @@ func (c *Config) GetAPIKey() string { if c.Providers.Zhipu.APIKey != "" { return c.Providers.Zhipu.APIKey } + if c.Providers.ZAI.APIKey != "" { + return c.Providers.ZAI.APIKey + } if c.Providers.Groq.APIKey != "" { return c.Providers.Groq.APIKey } diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 17eb6214c5..d9a25c2bbd 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -56,7 +56,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) if idx := strings.Index(model, "/"); idx != -1 { prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" { + if prefix == "moonshot" || prefix == "nvidia" || prefix == "zai" { model = model[idx+1:] } } @@ -277,6 +277,19 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://open.bigmodel.cn/api/paas/v4" } } + case "zai", "z.ai": + if cfg.Providers.ZAI.APIKey != "" { + apiKey = cfg.Providers.ZAI.APIKey + apiBase = cfg.Providers.ZAI.APIBase + proxy = cfg.Providers.ZAI.Proxy + if apiBase == "" { + if strings.Contains(lowerModel, "coding") { + apiBase = "https://api.z.ai/api/coding/paas/v4" + } else { + apiBase = "https://api.z.ai/api/paas/v4" + } + } + } case "gemini", "google": if cfg.Providers.Gemini.APIKey != "" { apiKey = cfg.Providers.Gemini.APIKey @@ -377,7 +390,19 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://generativelanguage.googleapis.com/v1beta" } - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "cogview") || strings.Contains(lowerModel, "cogvideo") || strings.Contains(lowerModel, "autoglm") || strings.Contains(lowerModel, "vidu")) && cfg.Providers.ZAI.APIKey != "": + apiKey = cfg.Providers.ZAI.APIKey + apiBase = cfg.Providers.ZAI.APIBase + proxy = cfg.Providers.ZAI.Proxy + if apiBase == "" { + if strings.Contains(lowerModel, "coding") { + apiBase = "https://api.z.ai/api/coding/paas/v4" + } else { + apiBase = "https://api.z.ai/api/paas/v4" + } + } + + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu")) && cfg.Providers.Zhipu.APIKey != "": apiKey = cfg.Providers.Zhipu.APIKey apiBase = cfg.Providers.Zhipu.APIBase proxy = cfg.Providers.Zhipu.Proxy