diff --git a/README.fr.md b/README.fr.md index 59dc0cb840..8a655a18c7 100644 --- a/README.fr.md +++ b/README.fr.md @@ -987,6 +987,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.ja.md b/README.ja.md index 6702fe3f8b..df493b2786 100644 --- a/README.ja.md +++ b/README.ja.md @@ -928,6 +928,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.md b/README.md index 16252f83bd..2cecae6f0d 100644 --- a/README.md +++ b/README.md @@ -1005,6 +1005,7 @@ The subagent has access to tools (message, web_search, etc.) 
and can communicate | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) | | `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) | +| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | ### Model Configuration (model_list) @@ -1041,6 +1042,7 @@ This design also enables **multi-agent support** with flexible provider selectio | **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.pt-br.md b/README.pt-br.md index 91bf236193..9365c79af6 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -983,6 +983,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter Chave](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | 
- | diff --git a/README.vi.md b/README.vi.md index 0ca95d5d74..a4bba93dc5 100644 --- a/README.vi.md +++ b/README.vi.md @@ -952,6 +952,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.zh.md b/README.zh.md index 6458cfcedf..b60579d406 100644 --- a/README.zh.md +++ b/README.zh.md @@ -524,6 +524,7 @@ Agent 读取 HEARTBEAT.md | **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) | | **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) | | **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) | +| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) | | **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/config/config.example.json b/config/config.example.json index 094aa46df2..1c11cd42a9 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -53,6 +53,12 @@ "api_key": "your-modelscope-access-token", "api_base": "https://api-inference.modelscope.cn/v1" }, + { + "model_name": "azure-gpt5", + "model": 
"azure/my-gpt5-deployment", + "api_key": "your-azure-api-key", + "api_base": "https://your-resource.openai.azure.com" + }, { "model_name": "loadbalanced-gpt-5.4", "model": "openai/gpt-5.4", diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 189af0a845..dc534d852d 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -384,6 +384,15 @@ func DefaultConfig() *Config { APIBase: "http://localhost:8000/v1", APIKey: "", }, + + // Azure OpenAI - https://portal.azure.com + // model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name + { + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIBase: "https://your-resource.openai.azure.com", + APIKey: "", + }, }, Gateway: GatewayConfig{ Host: "127.0.0.1", diff --git a/pkg/providers/azure/provider.go b/pkg/providers/azure/provider.go new file mode 100644 index 0000000000..6e1d07e78a --- /dev/null +++ b/pkg/providers/azure/provider.go @@ -0,0 +1,150 @@ +package azure + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/common" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + LLMResponse = protocoltypes.LLMResponse + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition +) + +const ( + // azureAPIVersion is the Azure OpenAI API version used for all requests. + azureAPIVersion = "2024-10-21" + defaultRequestTimeout = common.DefaultRequestTimeout +) + +// Provider implements the LLM provider interface for Azure OpenAI endpoints. +// It handles Azure-specific authentication (api-key header), URL construction +// (deployment-based), and request body formatting (max_completion_tokens, no model field). +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +// Option configures the Azure Provider. 
+type Option func(*Provider) + +// WithRequestTimeout sets the HTTP request timeout. +func WithRequestTimeout(timeout time.Duration) Option { + return func(p *Provider) { + if timeout > 0 { + p.httpClient.Timeout = timeout + } + } +} + +// NewProvider creates a new Azure OpenAI provider. +func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { + p := &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: common.NewHTTPClient(proxy), + } + + for _, opt := range opts { + if opt != nil { + opt(p) + } + } + + return p +} + +// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds. +func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider { + return NewProvider( + apiKey, apiBase, proxy, + WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), + ) +} + +// Chat sends a chat completion request to the Azure OpenAI endpoint. +// The model parameter is used as the Azure deployment name in the URL. +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("Azure API base not configured") + } + + // model is the deployment name for Azure OpenAI + deployment := model + + // Build Azure-specific URL safely using url.JoinPath and query encoding + // to prevent path traversal or query injection via deployment names. 
+ base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions") + if err != nil { + return nil, fmt.Errorf("failed to build Azure request URL: %w", err) + } + requestURL := base + "?api-version=" + azureAPIVersion + + // Build request body — no "model" field (Azure infers from deployment URL) + requestBody := map[string]any{ + "messages": common.SerializeMessages(messages), + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + // Azure OpenAI always uses max_completion_tokens + if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { + requestBody["max_completion_tokens"] = maxTokens + } + + if temperature, ok := common.AsFloat(options["temperature"]); ok { + requestBody["temperature"] = temperature + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Azure uses api-key header instead of Authorization: Bearer + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("api-key", p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, common.HandleErrorResponse(resp, p.apiBase) + } + + return common.ReadAndParseResponse(resp, p.apiBase) +} + +// GetDefaultModel returns an empty string as Azure deployments are user-configured. 
+func (p *Provider) GetDefaultModel() string { + return "" +} diff --git a/pkg/providers/azure/provider_test.go b/pkg/providers/azure/provider_test.go new file mode 100644 index 0000000000..8f44edff52 --- /dev/null +++ b/pkg/providers/azure/provider_test.go @@ -0,0 +1,232 @@ +package azure + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// writeValidResponse writes a minimal valid Azure OpenAI chat completion response. +func writeValidResponse(w http.ResponseWriter) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func TestProviderChat_AzureURLConstruction(t *testing.T) { + var capturedPath string + var capturedAPIVersion string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAPIVersion = r.URL.Query().Get("api-version") + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions" + if capturedPath != wantPath { + t.Errorf("URL path = %q, want %q", capturedPath, wantPath) + } + if capturedAPIVersion != azureAPIVersion { + t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion) + } +} + +func TestProviderChat_AzureAuthHeader(t *testing.T) { + var capturedAPIKey string + var capturedAuth string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAPIKey = r.Header.Get("api-key") + capturedAuth = r.Header.Get("Authorization") + writeValidResponse(w) + })) + defer server.Close() + + p := 
NewProvider("test-azure-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if capturedAPIKey != "test-azure-key" { + t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key") + } + if capturedAuth != "" { + t.Errorf("Authorization header should be empty, got %q", capturedAuth) + } +} + +func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, exists := requestBody["model"]; exists { + t.Error("request body should not contain 'model' field for Azure OpenAI") + } +} + +func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&requestBody) + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "deployment", + map[string]any{"max_tokens": 2048}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, exists := requestBody["max_completion_tokens"]; !exists { + t.Error("request body should contain 'max_completion_tokens'") + } + if _, exists := requestBody["max_tokens"]; exists { + t.Error("request body should not contain 'max_tokens'") + } +} + +func TestProviderChat_AzureHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) + })) + defer server.Close() + + p := NewProvider("bad-key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_AzureParseToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": `{"city":"Seattle"}`, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } +} + +func TestProvider_AzureEmptyAPIBase(t *testing.T) { + p := NewProvider("test-key", "", "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil) + if err == nil { + t.Fatal("expected error for empty API base") + } +} + +func TestProvider_AzureRequestTimeoutDefault(t *testing.T) { + p := NewProvider("test-key", "https://example.com", "") + if p.httpClient.Timeout != defaultRequestTimeout { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) + } +} + +func TestProvider_AzureRequestTimeoutOverride(t 
*testing.T) { + p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second)) + if p.httpClient.Timeout != 300*time.Second { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second) + } +} + +func TestProvider_AzureNewProviderWithTimeout(t *testing.T) { + p := NewProviderWithTimeout("test-key", "https://example.com", "", 180) + if p.httpClient.Timeout != 180*time.Second { + t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second) + } +} + +func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) { + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.RawPath // use RawPath to see percent-encoding + if capturedPath == "" { + capturedPath = r.URL.Path + } + writeValidResponse(w) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + + // Deployment name with characters that could cause path injection + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // The slash and special chars in the deployment name must be escaped, not treated as path separators + if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" { + t.Fatal("deployment name was interpolated without escaping — path injection possible") + } +} diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go new file mode 100644 index 0000000000..23680a1bf9 --- /dev/null +++ b/pkg/providers/common/common.go @@ -0,0 +1,380 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +// Package common provides shared utilities used by multiple LLM provider +// implementations (openai_compat, azure, etc.). 
+package common + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// Re-export protocol types used across providers. +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + ExtraContent = protocoltypes.ExtraContent + GoogleExtra = protocoltypes.GoogleExtra + ReasoningDetail = protocoltypes.ReasoningDetail +) + +const DefaultRequestTimeout = 120 * time.Second + +// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout. +func NewHTTPClient(proxy string) *http.Client { + client := &http.Client{ + Timeout: DefaultRequestTimeout, + } + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + // Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.) + if base, ok := http.DefaultTransport.(*http.Transport); ok { + tr := base.Clone() + tr.Proxy = http.ProxyURL(parsed) + client.Transport = tr + } else { + // Fallback: minimal transport if DefaultTransport is not *http.Transport. + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } + } else { + log.Printf("common: invalid proxy URL %q: %v", proxy, err) + } + } + return client +} + +// --- Message serialization --- + +// openaiMessage is the wire-format message for OpenAI-compatible APIs. +// It mirrors protocoltypes.Message but omits SystemParts, which is an +// internal field that would be unknown to third-party endpoints. 
+type openaiMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// SerializeMessages converts internal Message structs to the OpenAI wire format. +// - Strips SystemParts (unknown to third-party endpoints) +// - Converts messages with Media to multipart content format (text + image_url parts) +// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages +func SerializeMessages(messages []Message) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + if len(m.Media) == 0 { + out = append(out, openaiMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + }) + continue + } + + // Multipart content format for messages with media + parts := make([]map[string]any, 0, 1+len(m.Media)) + if m.Content != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": m.Content, + }) + } + for _, mediaURL := range m.Media { + if strings.HasPrefix(mediaURL, "data:image/") { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } + } + + msg := map[string]any{ + "role": m.Role, + "content": parts, + } + if m.ToolCallID != "" { + msg["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + msg["tool_calls"] = m.ToolCalls + } + if m.ReasoningContent != "" { + msg["reasoning_content"] = m.ReasoningContent + } + out = append(out, msg) + } + return out +} + +// --- Response parsing --- + +// ParseResponse parses a JSON chat completion response body into an LLMResponse. 
+func ParseResponse(body io.Reader) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` + ExtraContent *struct { + Google *struct { + ThoughtSignature string `json:"thought_signature"` + } `json:"google"` + } `json:"extra_content"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.NewDecoder(body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments := make(map[string]any) + name := "" + + // Extract thought_signature from Gemini/Google-specific extra content + thoughtSignature := "" + if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + thoughtSignature = tc.ExtraContent.Google.ThoughtSignature + } + + if tc.Function != nil { + name = tc.Function.Name + arguments = DecodeToolCallArguments(tc.Function.Arguments, name) + } + + toolCall := ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + ThoughtSignature: thoughtSignature, + } + + if thoughtSignature != "" { + toolCall.ExtraContent = &ExtraContent{ + Google: &GoogleExtra{ + ThoughtSignature: thoughtSignature, + }, + } + } + + toolCalls = append(toolCalls, toolCall) + } + + return &LLMResponse{ + Content: choice.Message.Content, + 
ReasoningContent: choice.Message.ReasoningContent, + Reasoning: choice.Message.Reasoning, + ReasoningDetails: choice.Message.ReasoningDetails, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} + +// DecodeToolCallArguments decodes a tool call's arguments from raw JSON. +func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any { + arguments := make(map[string]any) + raw = bytes.TrimSpace(raw) + if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { + return arguments + } + + var decoded any + if err := json.Unmarshal(raw, &decoded); err != nil { + log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err) + arguments["raw"] = string(raw) + return arguments + } + + switch v := decoded.(type) { + case string: + if strings.TrimSpace(v) == "" { + return arguments + } + if err := json.Unmarshal([]byte(v), &arguments); err != nil { + log.Printf("common: failed to decode tool call arguments for %q: %v", name, err) + arguments["raw"] = v + } + return arguments + case map[string]any: + return v + default: + log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded) + arguments["raw"] = string(raw) + return arguments + } +} + +// --- HTTP response helpers --- + +// HandleErrorResponse reads a non-200 response body and returns an appropriate error. 
+func HandleErrorResponse(resp *http.Response, apiBase string) error { + contentType := resp.Header.Get("Content-Type") + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256)) + if readErr != nil { + return fmt.Errorf("failed to read response: %w", readErr) + } + if LooksLikeHTML(body, contentType) { + return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase) + } + return fmt.Errorf( + "API request failed:\n Status: %d\n Body: %s", + resp.StatusCode, + ResponsePreview(body, 128), + ) +} + +// ReadAndParseResponse peeks at the response body to detect HTML errors, +// then parses the JSON response into an LLMResponse. +func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) { + contentType := resp.Header.Get("Content-Type") + reader := bufio.NewReader(resp.Body) + prefix, err := reader.Peek(256) + if err != nil && err != io.EOF && err != bufio.ErrBufferFull { + return nil, fmt.Errorf("failed to inspect response: %w", err) + } + if LooksLikeHTML(prefix, contentType) { + return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase) + } + out, err := ParseResponse(reader) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + return out, nil +} + +// LooksLikeHTML checks if the response body appears to be HTML. +func LooksLikeHTML(body []byte, contentType string) bool { + contentType = strings.ToLower(strings.TrimSpace(contentType)) + if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") { + return true + } + prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128)) + return bytes.HasPrefix(prefix, []byte("<html")) || bytes.HasPrefix(prefix, []byte("<!doctype")) +} + +// WrapHTMLResponseError wraps an HTML response body (typically a proxy or +// gateway error page rather than an API reply) into a readable error. +func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error { + return fmt.Errorf( + "API request to %s returned HTML instead of JSON (Content-Type: %s):\n Status: %d\n Body: %s", + apiBase, contentType, statusCode, ResponsePreview(body, 128), + ) +} + +// ResponsePreview returns a trimmed preview of a response body, capped at maxLen bytes. +func ResponsePreview(body []byte, maxLen int) string { + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 { + return "<empty>" + } + if len(trimmed) <= maxLen { + return string(trimmed) + } + return string(trimmed[:maxLen]) + "..."
+} + +func leadingTrimmedPrefix(body []byte, maxLen int) []byte { + i := 0 + for i < len(body) { + switch body[i] { + case ' ', '\t', '\n', '\r', '\f', '\v': + i++ + default: + end := i + maxLen + if end > len(body) { + end = len(body) + } + return body[i:end] + } + } + return nil +} + +// --- Numeric helpers --- + +// AsInt converts various numeric types to int. +func AsInt(v any) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +// AsFloat converts various numeric types to float64. +func AsFloat(v any) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go new file mode 100644 index 0000000000..bb7e7434d0 --- /dev/null +++ b/pkg/providers/common/common_test.go @@ -0,0 +1,558 @@ +package common + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// --- NewHTTPClient tests --- + +func TestNewHTTPClient_DefaultTimeout(t *testing.T) { + client := NewHTTPClient("") + if client.Timeout != DefaultRequestTimeout { + t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout) + } +} + +func TestNewHTTPClient_WithProxy(t *testing.T) { + client := NewHTTPClient("http://127.0.0.1:8080") + transport, ok := client.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http.Transport with proxy, got %T", client.Transport) + } + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + 
t.Fatalf("proxy function error: %v", err) + } + if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" { + t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy) + } +} + +func TestNewHTTPClient_NoProxy(t *testing.T) { + client := NewHTTPClient("") + if client.Transport != nil { + t.Errorf("expected nil transport without proxy, got %T", client.Transport) + } +} + +func TestNewHTTPClient_InvalidProxy(t *testing.T) { + // Should not panic, just log and return client without proxy + client := NewHTTPClient("://bad-url") + if client == nil { + t.Fatal("expected non-nil client even with invalid proxy") + } +} + +// --- SerializeMessages tests --- + +func TestSerializeMessages_PlainText(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["content"] != "hello" { + t.Errorf("expected plain string content, got %v", msgs[0]["content"]) + } + if msgs[1]["reasoning_content"] != "thinking..." 
{ + t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + } +} + +func TestSerializeMessages_WithMedia(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } +} + +func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { + messages := []Message{ + {Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["tool_call_id"] != "call_1" { + t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"]) + } +} + +func TestSerializeMessages_StripsSystemParts(t *testing.T) { + messages := []Message{ + { + Role: "system", + Content: "you are helpful", + SystemParts: []protocoltypes.ContentBlock{ + {Type: "text", Text: "you are helpful"}, + }, + }, + } + result := SerializeMessages(messages) + + data, _ := json.Marshal(result) + if strings.Contains(string(data), "system_parts") { + t.Error("system_parts should not appear in serialized output") + } +} + +// --- ParseResponse tests --- + +func TestParseResponse_BasicContent(t *testing.T) { + body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Content != "hello world" { + t.Errorf("Content = %q, want %q", out.Content, "hello world") + } + if out.FinishReason != "stop" { + 
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop") + } +} + +func TestParseResponse_EmptyChoices(t *testing.T) { + body := `{"choices":[]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Content != "" { + t.Errorf("Content = %q, want empty", out.Content) + } + if out.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop") + } +} + +func TestParseResponse_WithToolCalls(t *testing.T) { + body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestParseResponse_WithUsage(t *testing.T) { + body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.Usage == nil { + t.Fatal("Usage is nil") + } + if out.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens) + } +} + +func TestParseResponse_WithReasoningContent(t *testing.T) { + body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 
1+1=2"},"finish_reason":"stop"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if out.ReasoningContent != "Let me think... 1+1=2" { + t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2") + } +} + +func TestParseResponse_InvalidJSON(t *testing.T) { + _, err := ParseResponse(strings.NewReader("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// --- DecodeToolCallArguments tests --- + +func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) { + raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`) + args := DecodeToolCallArguments(raw, "test") + if args["city"] != "Seattle" { + t.Errorf("city = %v, want Seattle", args["city"]) + } + if args["units"] != "metric" { + t.Errorf("units = %v, want metric", args["units"]) + } +} + +func TestDecodeToolCallArguments_StringJSON(t *testing.T) { + raw := json.RawMessage(`"{\"city\":\"SF\"}"`) + args := DecodeToolCallArguments(raw, "test") + if args["city"] != "SF" { + t.Errorf("city = %v, want SF", args["city"]) + } +} + +func TestDecodeToolCallArguments_EmptyInput(t *testing.T) { + args := DecodeToolCallArguments(nil, "test") + if len(args) != 0 { + t.Errorf("expected empty map, got %v", args) + } +} + +func TestDecodeToolCallArguments_NullInput(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`null`), "test") + if len(args) != 0 { + t.Errorf("expected empty map, got %v", args) + } +} + +func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test") + if _, ok := args["raw"]; !ok { + t.Error("expected 'raw' fallback key for invalid JSON") + } +} + +func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) { + args := DecodeToolCallArguments(json.RawMessage(`" "`), "test") + if len(args) != 0 { + t.Errorf("expected empty map for whitespace string, got %v", args) + } +} + +// --- 
HandleErrorResponse tests --- + +func TestHandleErrorResponse_JSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"bad request"}`)) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("error should contain status code, got %v", err) + } + if strings.Contains(err.Error(), "HTML") { + t.Errorf("should not mention HTML for JSON error, got %v", err) + } +} + +func TestHandleErrorResponse_HTMLError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("bad gateway")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "HTML instead of JSON") { + t.Errorf("expected HTML error message, got %v", err) + } +} + +// --- ReadAndParseResponse tests --- + +func TestReadAndParseResponse_ValidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + out, err := ReadAndParseResponse(resp, server.URL) + if err 
!= nil { + t.Fatalf("ReadAndParseResponse() error = %v", err) + } + if out.Content != "ok" { + t.Errorf("Content = %q, want %q", out.Content, "ok") + } +} + +func TestReadAndParseResponse_HTMLResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte("login page")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + _, err = ReadAndParseResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error for HTML response") + } + if !strings.Contains(err.Error(), "HTML instead of JSON") { + t.Errorf("expected HTML error, got %v", err) + } +} + +// --- LooksLikeHTML tests --- + +func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) { + if !LooksLikeHTML(nil, "text/html; charset=utf-8") { + t.Error("expected true for text/html content type") + } +} + +func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) { + if !LooksLikeHTML(nil, "application/xhtml+xml") { + t.Error("expected true for xhtml content type") + } +} + +func TestLooksLikeHTML_BodyPrefix(t *testing.T) { + tests := []struct { + name string + body string + }{ + {"doctype", "<!DOCTYPE html>"}, + {"html tag", "<html>"}, + {"head tag", "<head>"}, + {"body tag", "<body>content"}, + {"whitespace before", " \n\t<!DOCTYPE html>"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !LooksLikeHTML([]byte(tt.body), "application/json") { + t.Errorf("expected true for body %q", tt.body) + } + }) + } +} + +func TestLooksLikeHTML_NotHTML(t *testing.T) { + if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") { + t.Error("expected false for JSON body") + } +} + +// --- ResponsePreview tests --- + +func TestResponsePreview_Short(t *testing.T) { + got := ResponsePreview([]byte("hello"), 128) + if got != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + +func 
TestResponsePreview_Truncated(t *testing.T) { + body := strings.Repeat("a", 200) + got := ResponsePreview([]byte(body), 128) + if len(got) != 131 { // 128 + "..." + t.Errorf("len = %d, want 131", len(got)) + } + if !strings.HasSuffix(got, "...") { + t.Error("expected ... suffix") + } +} + +func TestResponsePreview_Empty(t *testing.T) { + got := ResponsePreview([]byte(""), 128) + if got != "<empty>" { + t.Errorf("got %q, want %q", got, "<empty>") + } +} + +func TestResponsePreview_Whitespace(t *testing.T) { + got := ResponsePreview([]byte(" \n\t "), 128) + if got != "<empty>" { + t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>") + } +} + +// --- AsInt tests --- + +func TestAsInt(t *testing.T) { + tests := []struct { + name string + val any + want int + ok bool + }{ + {"int", 42, 42, true}, + {"int64", int64(99), 99, true}, + {"float64", float64(512), 512, true}, + {"float32", float32(256), 256, true}, + {"string", "nope", 0, false}, + {"nil", nil, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := AsInt(tt.val) + if ok != tt.ok || got != tt.want { + t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok) + } + }) + } +} + +// --- AsFloat tests --- + +func TestAsFloat(t *testing.T) { + tests := []struct { + name string + val any + want float64 + ok bool + }{ + {"float64", float64(0.7), 0.7, true}, + {"float32", float32(0.5), float64(float32(0.5)), true}, + {"int", 1, 1.0, true}, + {"int64", int64(100), 100.0, true}, + {"string", "nope", 0, false}, + {"nil", nil, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := AsFloat(tt.val) + if ok != tt.ok || got != tt.want { + t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok) + } + }) + } +} + +// --- WrapHTMLResponseError tests --- + +func TestWrapHTMLResponseError(t *testing.T) { + err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", 
"https://api.example.com") + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, "502") { + t.Errorf("expected status code in error, got %v", msg) + } + if !strings.Contains(msg, "https://api.example.com") { + t.Errorf("expected api base in error, got %v", msg) + } + if !strings.Contains(msg, "HTML instead of JSON") { + t.Errorf("expected HTML mention in error, got %v", msg) + } +} + +// --- HandleErrorResponse with read failure --- + +func TestHandleErrorResponse_EmptyBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + // empty body + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + err = HandleErrorResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected status code, got %v", err) + } +} + +// --- ReadAndParseResponse with invalid JSON --- + +func TestReadAndParseResponse_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("not valid json")) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("http.Get() error = %v", err) + } + defer resp.Body.Close() + _, err = ReadAndParseResponse(resp, server.URL) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// --- ParseResponse with thought_signature (Google/Gemini) --- + +func TestParseResponse_WithThoughtSignature(t *testing.T) { + body := 
`{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}` + out, err := ParseResponse(strings.NewReader(body)) + if err != nil { + t.Fatalf("ParseResponse() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].ThoughtSignature != "sig123" { + t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123") + } + if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil { + t.Fatal("ExtraContent.Google is nil") + } + if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" { + t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q", + out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123") + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index e99e07bc26..b7567f9fcf 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -11,6 +11,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" + "github.com/sipeed/picoclaw/pkg/providers/azure" ) // createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. @@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil + case "azure", "azure-openai": + // Azure OpenAI uses deployment-based URLs, api-key header auth, + // and always sends max_completion_tokens. 
+ if cfg.APIKey == "" { + return nil, "", fmt.Errorf("api_key is required for azure protocol") + } + if cfg.APIBase == "" { + return nil, "", fmt.Errorf( + "api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)", + ) + } + return azure.NewProviderWithTimeout( + cfg.APIKey, + cfg.APIBase, + cfg.Proxy, + cfg.RequestTimeout, + ), modelID, nil + case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian", diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 00676ebf98..b678a7eb61 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) { wantProtocol: "nvidia", wantModelID: "meta/llama-3.1-8b", }, + { + name: "azure with prefix", + model: "azure/my-gpt5-deployment", + wantProtocol: "azure", + wantModelID: "my-gpt5-deployment", + }, } for _, tt := range tests { @@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) { t.Fatalf("Chat() error = %q, want timeout-related error", errMsg) } } + +func TestCreateProviderFromConfig_Azure(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIKey: "test-azure-key", + APIBase: "https://my-resource.openai.azure.com", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-gpt5-deployment" { + t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment") + } +} + +func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt4", + Model: "azure-openai/my-deployment", + APIKey: 
"test-azure-key", + APIBase: "https://my-resource.openai.azure.com", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-deployment" { + t.Errorf("modelID = %q, want %q", modelID, "my-deployment") + } +} + +func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIBase: "https://my-resource.openai.azure.com", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API key") + } +} + +func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "azure-gpt5", + Model: "azure/my-gpt5-deployment", + APIKey: "test-azure-key", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API base") + } +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index f97bf3acd5..fb2abaa5c2 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -1,18 +1,16 @@ package openai_compat import ( - "bufio" "bytes" "context" "encoding/json" "fmt" - "io" - "log" "net/http" "net/url" "strings" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -38,7 +36,7 @@ type Provider struct { type Option func(*Provider) -const defaultRequestTimeout = 120 * time.Second +const defaultRequestTimeout = common.DefaultRequestTimeout func WithMaxTokensField(maxTokensField string) Option { return func(p *Provider) { @@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option { } func 
NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { - client := &http.Client{ - Timeout: defaultRequestTimeout, - } - - if proxy != "" { - parsed, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(parsed), - } - } else { - log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) - } - } - p := &Provider{ apiKey: apiKey, apiBase: strings.TrimRight(apiBase, "/"), - httpClient: client, + httpClient: common.NewHTTPClient(proxy), } for _, opt := range opts { @@ -117,7 +100,7 @@ func (p *Provider) Chat( requestBody := map[string]any{ "model": model, - "messages": serializeMessages(messages), + "messages": common.SerializeMessages(messages), } if len(tools) > 0 { @@ -125,7 +108,7 @@ func (p *Provider) Chat( requestBody["tool_choice"] = "auto" } - if maxTokens, ok := asInt(options["max_tokens"]); ok { + if maxTokens, ok := common.AsInt(options["max_tokens"]); ok { // Use configured maxTokensField if specified, otherwise fallback to model-based detection fieldName := p.maxTokensField if fieldName == "" { @@ -141,7 +124,7 @@ func (p *Provider) Chat( requestBody[fieldName] = maxTokens } - if temperature, ok := asFloat(options["temperature"]); ok { + if temperature, ok := common.AsFloat(options["temperature"]); ok { lowerModel := strings.ToLower(model) // Kimi k2 models only support temperature=1. if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { @@ -185,275 +168,11 @@ func (p *Provider) Chat( } defer resp.Body.Close() - contentType := resp.Header.Get("Content-Type") - - // Non-200: read a prefix to tell HTML error page apart from JSON error body. 
if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256)) - if readErr != nil { - return nil, fmt.Errorf("failed to read response: %w", readErr) - } - if looksLikeHTML(body, contentType) { - return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase) - } - return nil, fmt.Errorf( - "API request failed:\n Status: %d\n Body: %s", - resp.StatusCode, - responsePreview(body, 128), - ) - } - - // Peek without consuming so the full stream reaches the JSON decoder. - reader := bufio.NewReader(resp.Body) - prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort - if err != nil && err != io.EOF && err != bufio.ErrBufferFull { - return nil, fmt.Errorf("failed to inspect response: %w", err) - } - if looksLikeHTML(prefix, contentType) { - return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase) - } - - out, err := parseResponse(reader) - if err != nil { - return nil, fmt.Errorf("failed to parse JSON response: %w", err) + return nil, common.HandleErrorResponse(resp, p.apiBase) } - return out, nil -} - -func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error { - respPreview := responsePreview(body, 128) - return fmt.Errorf( - "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s", - apiBase, - contentType, - statusCode, - respPreview, - ) -} - -func looksLikeHTML(body []byte, contentType string) bool { - contentType = strings.ToLower(strings.TrimSpace(contentType)) - if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") { - return true - } - prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128)) - return bytes.HasPrefix(prefix, []byte("<!doctype html")) || - bytes.HasPrefix(prefix, []byte("<html")) || - bytes.HasPrefix(prefix, []byte("<head")) || - bytes.HasPrefix(prefix, []byte("<body")) -} 
- -func leadingTrimmedPrefix(body []byte, maxLen int) []byte { - i := 0 - for i < len(body) { - switch body[i] { - case ' ', '\t', '\n', '\r', '\f', '\v': - i++ - default: - end := i + maxLen - if end > len(body) { - end = len(body) - } - return body[i:end] - } - } - return nil -} - -func responsePreview(body []byte, maxLen int) string { - trimmed := bytes.TrimSpace(body) - if len(trimmed) == 0 { - return "<empty>" - } - if len(trimmed) <= maxLen { - return string(trimmed) - } - return string(trimmed[:maxLen]) + "..." -} - -func parseResponse(body io.Reader) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content"` - Reasoning string `json:"reasoning"` - ReasoningDetails []ReasoningDetail `json:"reasoning_details"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` - } `json:"function"` - ExtraContent *struct { - Google *struct { - ThoughtSignature string `json:"thought_signature"` - } `json:"google"` - } `json:"extra_content"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.NewDecoder(body).Decode(&apiResponse); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]any) - name := "" - - // Extract thought_signature from Gemini/Google-specific extra content - thoughtSignature := "" - if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { - thoughtSignature = 
tc.ExtraContent.Google.ThoughtSignature - } - - if tc.Function != nil { - name = tc.Function.Name - arguments = decodeToolCallArguments(tc.Function.Arguments, name) - } - - // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence - toolCall := ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - ThoughtSignature: thoughtSignature, - } - - if thoughtSignature != "" { - toolCall.ExtraContent = &ExtraContent{ - Google: &GoogleExtra{ - ThoughtSignature: thoughtSignature, - }, - } - } - - toolCalls = append(toolCalls, toolCall) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ReasoningContent: choice.Message.ReasoningContent, - Reasoning: choice.Message.Reasoning, - ReasoningDetails: choice.Message.ReasoningDetails, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil -} - -func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any { - arguments := make(map[string]any) - raw = bytes.TrimSpace(raw) - if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { - return arguments - } - - var decoded any - if err := json.Unmarshal(raw, &decoded); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err) - arguments["raw"] = string(raw) - return arguments - } - - switch v := decoded.(type) { - case string: - if strings.TrimSpace(v) == "" { - return arguments - } - if err := json.Unmarshal([]byte(v), &arguments); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) - arguments["raw"] = v - } - return arguments - case map[string]any: - return v - default: - log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded) - arguments["raw"] = string(raw) - return arguments - } -} - -// openaiMessage is the wire-format message for OpenAI-compatible APIs. 
-// It mirrors protocoltypes.Message but omits SystemParts, which is an -// internal field that would be unknown to third-party endpoints. -type openaiMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -// serializeMessages converts internal Message structs to the OpenAI wire format. -// - Strips SystemParts (unknown to third-party endpoints) -// - Converts messages with Media to multipart content format (text + image_url parts) -// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages -func serializeMessages(messages []Message) []any { - out := make([]any, 0, len(messages)) - for _, m := range messages { - if len(m.Media) == 0 { - out = append(out, openaiMessage{ - Role: m.Role, - Content: m.Content, - ReasoningContent: m.ReasoningContent, - ToolCalls: m.ToolCalls, - ToolCallID: m.ToolCallID, - }) - continue - } - - // Multipart content format for messages with media - parts := make([]map[string]any, 0, 1+len(m.Media)) - if m.Content != "" { - parts = append(parts, map[string]any{ - "type": "text", - "text": m.Content, - }) - } - for _, mediaURL := range m.Media { - if strings.HasPrefix(mediaURL, "data:image/") { - parts = append(parts, map[string]any{ - "type": "image_url", - "image_url": map[string]any{ - "url": mediaURL, - }, - }) - } - } - - msg := map[string]any{ - "role": m.Role, - "content": parts, - } - if m.ToolCallID != "" { - msg["tool_call_id"] = m.ToolCallID - } - if len(m.ToolCalls) > 0 { - msg["tool_calls"] = m.ToolCalls - } - if m.ReasoningContent != "" { - msg["reasoning_content"] = m.ReasoningContent - } - out = append(out, msg) - } - return out + return common.ReadAndParseResponse(resp, p.apiBase) } func normalizeModel(model, apiBase string) string { @@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string { } } 
-func asInt(v any) (int, bool) { - switch val := v.(type) { - case int: - return val, true - case int64: - return int(val), true - case float64: - return int(val), true - case float32: - return int(val), true - default: - return 0, false - } -} - -func asFloat(v any) (float64, bool) { - switch val := v.(type) { - case float64: - return val, true - case float32: - return float64(val), true - case int: - return float64(val), true - case int64: - return float64(val), true - default: - return 0, false - } -} - // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own // API and Azure OpenAI support this. All other OpenAI-compatible providers diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 41f278a1b1..ed9747f9d7 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) { {Role: "user", Content: "hello"}, {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, err := json.Marshal(result) if err != nil { @@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) { messages := []protocoltypes.Message{ {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) var msgs []map[string]any @@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { messages := []protocoltypes.Message{ {Role: "tool", Content: "image result", Media: 
[]string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) var msgs []map[string]any @@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) { }, }, } - result := serializeMessages(messages) + result := common.SerializeMessages(messages) data, _ := json.Marshal(result) raw := string(data)