Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/migration/model-list-migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier`
| `api_base` | No | API endpoint URL |
| `api_key` | No* | API authentication key |
| `proxy` | No | HTTP proxy URL |
| `headers` | No | Custom HTTP headers (`Authorization` and `Content-Type` are reserved and will be ignored) |
| `auth_method` | No | Authentication method: `oauth`, `token` |
| `connect_mode` | No | Connection mode for CLI providers: `stdio`, `grpc` |
| `rpm` | No | Requests per minute limit |
Expand Down
1 change: 1 addition & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ type ModelConfig struct {
APIBase string `json:"api_base,omitempty"` // API endpoint URL
APIKey string `json:"api_key"` // API authentication key
Proxy string `json:"proxy,omitempty"` // HTTP proxy URL
Headers map[string]string `json:"headers,omitempty"` // Custom HTTP headers (Authorization and Content-Type are reserved and will be ignored)

// Special providers (CLI-based, OAuth, etc.)
AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token
Expand Down
9 changes: 6 additions & 3 deletions pkg/providers/factory_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
return NewHTTPProviderWithOptions(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.Headers,
), modelID, nil

case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
Expand All @@ -103,12 +104,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
return NewHTTPProviderWithOptions(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.Headers,
), modelID, nil

case "anthropic":
Expand All @@ -128,12 +130,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
}
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
return NewHTTPProviderWithOptions(
cfg.APIKey,
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
cfg.RequestTimeout,
cfg.Headers,
), modelID, nil

case "antigravity":
Expand Down
23 changes: 16 additions & 7 deletions pkg/providers/http_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,23 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
apiKey, apiBase, proxy, maxTokensField string,
requestTimeoutSeconds int,
) *HTTPProvider {
return NewHTTPProviderWithOptions(apiKey, apiBase, proxy, maxTokensField, requestTimeoutSeconds, nil)
}

// NewHTTPProviderWithOptions builds an HTTPProvider whose delegate is an
// openai_compat provider.
//
// apiKey, apiBase and proxy are handed to openai_compat.NewProvider as-is.
// maxTokensField is applied through openai_compat.WithMaxTokensField, and
// requestTimeoutSeconds is converted to a time.Duration (whole seconds) for
// openai_compat.WithRequestTimeout. customHeaders is forwarded through
// openai_compat.WithCustomHeaders only when the map is non-nil.
func NewHTTPProviderWithOptions(
apiKey, apiBase, proxy, maxTokensField string,
requestTimeoutSeconds int,
customHeaders map[string]string,
) *HTTPProvider {
// Options that are always applied, regardless of custom headers.
opts := []openai_compat.Option{
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
}
// A nil map adds no option at all; a non-nil (even empty) map is forwarded.
if customHeaders != nil {
opts = append(opts, openai_compat.WithCustomHeaders(customHeaders))
}
return &HTTPProvider{
// NOTE(review): two delegate initializers are visible in this view. The
// multi-line call directly below looks like the pre-change side of the diff
// and the single-line call using opts... the post-change side — confirm
// against the actual file before relying on either.
delegate: openai_compat.NewProvider(
apiKey,
apiBase,
proxy,
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
),
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy, opts...),
}
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Provider struct {
apiKey string
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
customHeaders map[string]string
httpClient *http.Client
}

Expand All @@ -53,6 +54,12 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
}

// WithCustomHeaders returns an Option that sets the extra HTTP headers a
// Provider adds to its chat requests.
//
// The map is copied once, when WithCustomHeaders is called, so the provider
// never aliases the caller's map: mutating headers afterwards neither changes
// the headers that are sent nor races with Chat ranging over them. A nil map
// is preserved as nil, which leaves the provider without custom headers.
//
// Chat skips the Authorization and Content-Type keys when applying these
// headers, so those two cannot be overridden through this option.
func WithCustomHeaders(headers map[string]string) Option {
	var snapshot map[string]string
	if headers != nil {
		snapshot = make(map[string]string, len(headers))
		for name, value := range headers {
			snapshot[name] = value
		}
	}
	return func(p *Provider) {
		// The snapshot is never written after this point, so sharing it is safe
		// even if the same Option value is applied to several providers.
		p.customHeaders = snapshot
	}
}

func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: defaultRequestTimeout,
Expand Down Expand Up @@ -176,6 +183,12 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
for key, value := range p.customHeaders {
if k := http.CanonicalHeaderKey(key); k == "Authorization" || k == "Content-Type" {
continue
}
req.Header.Set(key, value)
}

resp, err := p.httpClient.Do(req)
if err != nil {
Expand Down
168 changes: 95 additions & 73 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"net/http"
"net/http/httptest"
"net/url"
"strings"

Check failure on line 8 in pkg/providers/openai_compat/provider_test.go

View workflow job for this annotation

GitHub Actions / Tests

"strings" imported and not used

Check failure on line 8 in pkg/providers/openai_compat/provider_test.go

View workflow job for this annotation

GitHub Actions / Linter

"strings" imported and not used
"testing"
"time"

"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"

Check failure on line 12 in pkg/providers/openai_compat/provider_test.go

View workflow job for this annotation

GitHub Actions / Tests

"github.com/sipeed/picoclaw/pkg/providers/protocoltypes" imported and not used

Check failure on line 12 in pkg/providers/openai_compat/provider_test.go

View workflow job for this annotation

GitHub Actions / Linter

"github.com/sipeed/picoclaw/pkg/providers/protocoltypes" imported and not used (typecheck)
)

func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
Expand Down Expand Up @@ -420,96 +420,118 @@
}
}

func TestSerializeMessages_PlainText(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := serializeMessages(messages)
func TestProviderChat_SendsCustomHeaders(t *testing.T) {
var gotHeader string

data, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}

var msgs []map[string]any
json.Unmarshal(data, &msgs)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeader = r.Header.Get("X-Custom-Key")
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()

if msgs[0]["content"] != "hello" {
t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
p := NewProvider("key", server.URL, "", WithCustomHeaders(map[string]string{
"X-Custom-Key": "myvalue",
}))
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if msgs[1]["reasoning_content"] != "thinking..." {
t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
if gotHeader != "myvalue" {
t.Fatalf("X-Custom-Key = %q, want %q", gotHeader, "myvalue")
}
}

func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(messages)

data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
func TestProviderChat_SkipsReservedHeaders(t *testing.T) {
var gotAuth, gotContentType string

content, ok := msgs[0]["content"].([]any)
if !ok {
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
}
if len(content) != 2 {
t.Fatalf("expected 2 content parts, got %d", len(content))
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
gotContentType = r.Header.Get("Content-Type")
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()

textPart := content[0].(map[string]any)
if textPart["type"] != "text" || textPart["text"] != "describe this" {
t.Fatalf("text part mismatch: %v", textPart)
p := NewProvider("mykey", server.URL, "", WithCustomHeaders(map[string]string{
"Authorization": "Bearer OVERRIDE",
"Content-Type": "text/plain",
"X-Extra": "allowed",
}))
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}

imgPart := content[1].(map[string]any)
if imgPart["type"] != "image_url" {
t.Fatalf("expected image_url type, got %v", imgPart["type"])
if gotAuth != "Bearer mykey" {
t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer mykey")
}
imgURL := imgPart["image_url"].(map[string]any)
if imgURL["url"] != "data:image/png;base64,abc123" {
t.Fatalf("image url mismatch: %v", imgURL["url"])
if gotContentType != "application/json" {
t.Fatalf("Content-Type = %q, want %q", gotContentType, "application/json")
}
}

func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := serializeMessages(messages)

data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)

if msgs[0]["tool_call_id"] != "call_1" {
t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
}
// Content should be multipart array
if _, ok := msgs[0]["content"].([]any); !ok {
t.Fatalf("expected array content, got %T", msgs[0]["content"])
}
}
func TestProviderChat_RoundTripsReasoningContent(t *testing.T) {
var reqBody map[string]any

func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []protocoltypes.Message{
{
Role: "system",
Content: "you are helpful",
SystemParts: []protocoltypes.ContentBlock{
{Type: "text", Text: "you are helpful"},
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()

p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{
{Role: "user", Content: "1+1=?"},
{Role: "assistant", Content: "2", ReasoningContent: "let me think..."},
{Role: "user", Content: "thanks"},
},
nil,
"gpt-4o",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
result := serializeMessages(messages)

data, _ := json.Marshal(result)
raw := string(data)
if strings.Contains(raw, "system_parts") {
t.Fatal("system_parts should not appear in serialized output")
msgs, ok := reqBody["messages"].([]any)
if !ok {
t.Fatalf("messages is not []any")
}
assistantMsg, ok := msgs[1].(map[string]any)
if !ok {
t.Fatalf("messages[1] is not map[string]any")
}
if got := assistantMsg["reasoning_content"]; got != "let me think..." {
t.Fatalf("reasoning_content = %v, want %q", got, "let me think...")
}
}
1 change: 1 addition & 0 deletions pkg/tools/toolloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func RunToolLoop(
Role: "assistant",
Content: response.Content,
}
assistantMsg.ReasoningContent = response.ReasoningContent
for _, tc := range normalizedToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
Expand Down
Loading