diff --git a/docs/migration/model-list-migration.md b/docs/migration/model-list-migration.md index 0d4af719ca..41b196b297 100644 --- a/docs/migration/model-list-migration.md +++ b/docs/migration/model-list-migration.md @@ -113,6 +113,7 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier` | `api_base` | No | API endpoint URL | | `api_key` | No* | API authentication key | | `proxy` | No | HTTP proxy URL | +| `headers` | No | Custom HTTP headers (`Authorization` and `Content-Type` are reserved and will be ignored) | | `auth_method` | No | Authentication method: `oauth`, `token` | | `connect_mode` | No | Connection mode for CLI providers: `stdio`, `grpc` | | `rpm` | No | Requests per minute limit | diff --git a/pkg/config/config.go b/pkg/config/config.go index 0ee3acfe05..d6efbe3b50 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -497,6 +497,7 @@ type ModelConfig struct { APIBase string `json:"api_base,omitempty"` // API endpoint URL APIKey string `json:"api_key"` // API authentication key Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + Headers map[string]string `json:"headers,omitempty"` // Custom HTTP headers (Authorization and Content-Type are reserved and will be ignored) // Special providers (CLI-based, OAuth, etc.) AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index c05fb0ad4d..072fea7ed0 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -84,12 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + return NewHTTPProviderWithOptions( cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.Headers, ), modelID, nil case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", @@ -103,12 +104,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + return NewHTTPProviderWithOptions( cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.Headers, ), modelID, nil case "anthropic": @@ -128,12 +130,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if cfg.APIKey == "" { return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model) } - return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( + return NewHTTPProviderWithOptions( cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.Headers, ), modelID, nil case "antigravity": diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 5c328f418c..7f7fbd7f74 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -31,14 +31,23 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( apiKey, apiBase, proxy, maxTokensField string, requestTimeoutSeconds int, ) *HTTPProvider { + return NewHTTPProviderWithOptions(apiKey, apiBase, proxy, maxTokensField, requestTimeoutSeconds, nil) +} + +func NewHTTPProviderWithOptions( + apiKey, apiBase, proxy, maxTokensField string, + requestTimeoutSeconds int, + customHeaders map[string]string, +) *HTTPProvider { + opts := []openai_compat.Option{ + openai_compat.WithMaxTokensField(maxTokensField), + openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), + } + if customHeaders != nil { + opts = append(opts, openai_compat.WithCustomHeaders(customHeaders)) + } return &HTTPProvider{ - delegate: openai_compat.NewProvider( - apiKey, - apiBase, - proxy, - openai_compat.WithMaxTokensField(maxTokensField), - openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), - ), + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy, opts...), } } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 1904ee1533..550f37e01d 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -32,6 +32,7 @@ type Provider struct { apiKey string apiBase string maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) + customHeaders map[string]string httpClient *http.Client } @@ -53,6 +54,12 @@ func WithRequestTimeout(timeout time.Duration) Option { } } +func WithCustomHeaders(headers map[string]string) Option { + return func(p *Provider) { + p.customHeaders = headers + } +} + func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { client := &http.Client{ Timeout: defaultRequestTimeout, @@ -176,6 +183,12 @@ func (p *Provider) Chat( if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } + for key, value := range p.customHeaders { + if k := http.CanonicalHeaderKey(key); k == "Authorization" || k == "Content-Type" { + continue + } + req.Header.Set(key, value) + } resp, err := p.httpClient.Do(req) if err != nil { diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 174bcf00d8..45fb43e7f5 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -420,96 +420,118 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) { } } -func TestSerializeMessages_PlainText(t *testing.T) { - messages := []protocoltypes.Message{ - {Role: "user", Content: "hello"}, - {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, - } - result := serializeMessages(messages) +func TestProviderChat_SendsCustomHeaders(t *testing.T) { + var gotHeader string - data, err := json.Marshal(result) - if err != nil { - t.Fatal(err) - } - - var msgs []map[string]any - json.Unmarshal(data, &msgs) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("X-Custom-Key") + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() - if msgs[0]["content"] != "hello" { - t.Fatalf("expected plain string content, got %v", msgs[0]["content"]) + p := NewProvider("key", server.URL, "", WithCustomHeaders(map[string]string{ + "X-Custom-Key": "myvalue", + })) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) } - if msgs[1]["reasoning_content"] != "thinking..." { - t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + if gotHeader != "myvalue" { + t.Fatalf("X-Custom-Key = %q, want %q", gotHeader, "myvalue") } } -func TestSerializeMessages_WithMedia(t *testing.T) { - messages := []protocoltypes.Message{ - {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, - } - result := serializeMessages(messages) - - data, _ := json.Marshal(result) - var msgs []map[string]any - json.Unmarshal(data, &msgs) +func TestProviderChat_SkipsReservedHeaders(t *testing.T) { + var gotAuth, gotContentType string - content, ok := msgs[0]["content"].([]any) - if !ok { - t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) - } - if len(content) != 2 { - t.Fatalf("expected 2 content parts, got %d", len(content)) - } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotContentType = r.Header.Get("Content-Type") + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() - textPart := content[0].(map[string]any) - if textPart["type"] != "text" || textPart["text"] != "describe this" { - t.Fatalf("text part mismatch: %v", textPart) + p := NewProvider("mykey", server.URL, "", WithCustomHeaders(map[string]string{ + "Authorization": "Bearer OVERRIDE", + "Content-Type": "text/plain", + "X-Extra": "allowed", + })) + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) } - - imgPart := content[1].(map[string]any) - if imgPart["type"] != "image_url" { - t.Fatalf("expected image_url type, got %v", imgPart["type"]) + if gotAuth != "Bearer mykey" { + t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer mykey") } - imgURL := imgPart["image_url"].(map[string]any) - if imgURL["url"] != "data:image/png;base64,abc123" { - t.Fatalf("image url mismatch: %v", imgURL["url"]) + if gotContentType != "application/json" { + t.Fatalf("Content-Type = %q, want %q", gotContentType, "application/json") } } -func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { - messages := []protocoltypes.Message{ - {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, - } - result := serializeMessages(messages) - - data, _ := json.Marshal(result) - var msgs []map[string]any - json.Unmarshal(data, &msgs) - - if msgs[0]["tool_call_id"] != "call_1" { - t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"]) - } - // Content should be multipart array - if _, ok := msgs[0]["content"].([]any); !ok { - t.Fatalf("expected array content, got %T", msgs[0]["content"]) - } -} +func TestProviderChat_RoundTripsReasoningContent(t *testing.T) { + var reqBody map[string]any -func TestSerializeMessages_StripsSystemParts(t *testing.T) { - messages := []protocoltypes.Message{ - { - Role: "system", - Content: "you are helpful", - SystemParts: []protocoltypes.ContentBlock{ - {Type: "text", Text: "you are helpful"}, + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{ + {Role: "user", Content: "1+1=?"}, + {Role: "assistant", Content: "2", ReasoningContent: "let me think..."}, + {Role: "user", Content: "thanks"}, }, + nil, + "gpt-4o", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) } - result := serializeMessages(messages) - data, _ := json.Marshal(result) - raw := string(data) - if strings.Contains(raw, "system_parts") { - t.Fatal("system_parts should not appear in serialized output") + msgs, ok := reqBody["messages"].([]any) + if !ok { + t.Fatalf("messages is not []any") + } + assistantMsg, ok := msgs[1].(map[string]any) + if !ok { + t.Fatalf("messages[1] is not map[string]any") + } + if got := assistantMsg["reasoning_content"]; got != "let me think..." { + t.Fatalf("reasoning_content = %v, want %q", got, "let me think...") } } diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 244f0d4a24..d7a38b0de4 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -107,6 +107,7 @@ func RunToolLoop( Role: "assistant", Content: response.Content, } + assistantMsg.ReasoningContent = response.ReasoningContent for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{