diff --git a/README.md b/README.md index 5cf9f61438..82a0e44f26 100644 --- a/README.md +++ b/README.md @@ -1097,6 +1097,26 @@ This design also enables **multi-agent support** with flexible provider selectio > Run `picoclaw auth login --provider anthropic` to paste your API token. +**Anthropic Messages API (native format)** + +For direct Anthropic API access or custom endpoints that only support Anthropic's native message format: + +```json +{ + "model_name": "claude-opus-4-6", + "model": "anthropic-messages/claude-opus-4-6", + "api_key": "sk-ant-your-key", + "api_base": "https://api.anthropic.com" +} +``` + +> Use `anthropic-messages` protocol when: +> - Using third-party proxies that only support Anthropic's native `/v1/messages` endpoint (not OpenAI-compatible `/v1/chat/completions`) +> - Connecting to services like MiniMax, Synthetic that require Anthropic's native message format +> - The existing `anthropic` protocol returns 404 errors (indicating the endpoint doesn't support OpenAI-compatible format) +> +> **Note:** The `anthropic` protocol uses OpenAI-compatible format (`/v1/chat/completions`), while `anthropic-messages` uses Anthropic's native format (`/v1/messages`). Choose based on your endpoint's supported format. + **Ollama (local)** ```json diff --git a/README.zh.md b/README.zh.md index c744e0d20c..37360d4fd7 100644 --- a/README.zh.md +++ b/README.zh.md @@ -593,6 +593,26 @@ Agent 读取 HEARTBEAT.md > 运行 `picoclaw auth login --provider anthropic` 来设置 OAuth 凭证。 +**Anthropic Messages API(原生格式)** + +用于直接访问 Anthropic API 或仅支持 Anthropic 原生消息格式的自定义端点: + +```json +{ + "model_name": "claude-opus-4-6", + "model": "anthropic-messages/claude-opus-4-6", + "api_key": "sk-ant-your-key", + "api_base": "https://api.anthropic.com" +} +``` + +> 使用 `anthropic-messages` 协议的场景: +> - 使用仅支持 Anthropic 原生 `/v1/messages` 端点的第三方代理(不支持 OpenAI 兼容的 `/v1/chat/completions`) +> - 连接到 MiniMax、Synthetic 等需要 Anthropic 原生消息格式的服务 +> - 现有的 `anthropic` 协议返回 404 错误(说明端点不支持 OpenAI 兼容格式) +> +> **注意:** `anthropic` 协议使用 OpenAI 兼容格式(`/v1/chat/completions`),而 `anthropic-messages` 使用 Anthropic 原生格式(`/v1/messages`)。请根据端点支持的格式选择。 + **Ollama (本地)** ```json diff --git a/config/config.example.json b/config/config.example.json index 0e2cae8e5b..4ae4613816 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -25,6 +25,13 @@ "api_base": "https://api.anthropic.com/v1", "thinking_level": "high" }, + { + "_comment": "Anthropic Messages API - use native format for direct Anthropic API access", + "model_name": "claude-opus-4-6", + "model": "anthropic-messages/claude-opus-4-6", + "api_key": "sk-ant-your-key", + "api_base": "https://api.anthropic.com" + }, { "model_name": "gemini", "model": "antigravity/gemini-2.0-flash", diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go new file mode 100644 index 0000000000..8a83a7058f --- /dev/null +++ b/pkg/providers/anthropic_messages/provider.go @@ -0,0 +1,415 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package anthropicmessages + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition +) + +const ( + defaultAPIVersion = "2023-06-01" + defaultBaseURL = "https://api.anthropic.com/v1" + defaultRequestTimeout = 120 * time.Second +) + +// Provider implements Anthropic Messages API via HTTP (without SDK). +// It supports custom endpoints that use Anthropic's native message format. +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +// NewProvider creates a new Anthropic Messages API provider. +func NewProvider(apiKey, apiBase string) *Provider { + return NewProviderWithTimeout(apiKey, apiBase, 0) +} + +// NewProviderWithTimeout creates a provider with custom request timeout. +func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provider { + baseURL := normalizeBaseURL(apiBase) + timeout := defaultRequestTimeout + if timeoutSeconds > 0 { + timeout = time.Duration(timeoutSeconds) * time.Second + } + + return &Provider{ + apiKey: apiKey, + apiBase: baseURL, + httpClient: &http.Client{ + Timeout: timeout, + }, + } +} + +// Chat sends messages to the Anthropic Messages API and returns the response. +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiKey == "" { + return nil, fmt.Errorf("API key not configured") + } + + // Build request body + requestBody, err := buildRequestBody(messages, tools, model, options) + if err != nil { + return nil, fmt.Errorf("building request body: %w", err) + } + + // Serialize to JSON + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("serializing request body: %w", err) + } + + // Build request URL + endpointURL, err := url.JoinPath(p.apiBase, "messages") + if err != nil { + return nil, fmt.Errorf("building endpoint URL: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("creating HTTP request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-API-Key", p.apiKey) //nolint:canonicalheader // Anthropic API requires exact header name + req.Header.Set("Anthropic-Version", defaultAPIVersion) + + // Execute request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("executing HTTP request: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + + // Check for HTTP errors with detailed messages + switch resp.StatusCode { + case http.StatusUnauthorized: + return nil, fmt.Errorf("authentication failed (401): check your API key") + case http.StatusTooManyRequests: + return nil, fmt.Errorf("rate limited (429): %s", string(body)) + case http.StatusBadRequest: + return nil, fmt.Errorf("bad request (400): %s", string(body)) + case http.StatusNotFound: + return nil, fmt.Errorf("endpoint not found (404): %s", string(body)) + case http.StatusInternalServerError: + return nil, fmt.Errorf("internal server error (500): %s", string(body)) + case http.StatusServiceUnavailable: + return nil, fmt.Errorf("service unavailable (503): %s", string(body)) + default: + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + } + + // Parse response + return parseResponseBody(body) +} + +// GetDefaultModel returns the default model for this provider. +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4.6" +} + +// buildRequestBody converts internal message format to Anthropic Messages API format. +func buildRequestBody( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (map[string]any, error) { + // max_tokens is required and guaranteed by agent loop + maxTokens, ok := asInt(options["max_tokens"]) + if !ok { + return nil, fmt.Errorf("max_tokens is required in options") + } + + result := map[string]any{ + "model": model, + "max_tokens": int64(maxTokens), + "messages": []any{}, + } + + // Set temperature from options + if temp, ok := asFloat(options["temperature"]); ok { + result["temperature"] = temp + } + + // Process messages + var systemPrompt string + var apiMessages []any + + for _, msg := range messages { + switch msg.Role { + case "system": + // Accumulate system messages + if systemPrompt != "" { + systemPrompt += "\n\n" + msg.Content + } else { + systemPrompt = msg.Content + } + + case "user": + if msg.ToolCallID != "" { + // Tool result message + content := []map[string]any{ + { + "type": "tool_result", + "tool_use_id": msg.ToolCallID, + "content": msg.Content, + }, + } + apiMessages = append(apiMessages, map[string]any{ + "role": "user", + "content": content, + }) + } else { + // Regular user message + apiMessages = append(apiMessages, map[string]any{ + "role": "user", + "content": msg.Content, + }) + } + + case "assistant": + content := []any{} + + // Add text content if present + if msg.Content != "" { + content = append(content, map[string]any{ + "type": "text", + "text": msg.Content, + }) + } + + // Add tool_use blocks + for _, tc := range msg.ToolCalls { + toolUse := map[string]any{ + "type": "tool_use", + "id": tc.ID, + "name": tc.Name, + "input": tc.Arguments, + } + content = append(content, toolUse) + } + + apiMessages = append(apiMessages, map[string]any{ + "role": "assistant", + "content": content, + }) + + case "tool": + // Tool result (alternative format) + content := []map[string]any{ + { + "type": "tool_result", + "tool_use_id": msg.ToolCallID, + "content": msg.Content, + }, + } + apiMessages = append(apiMessages, map[string]any{ + "role": "user", + "content": content, + }) + } + } + + result["messages"] = apiMessages + + // Set system prompt if present + if systemPrompt != "" { + result["system"] = systemPrompt + } + + // Add tools if present + if len(tools) > 0 { + result["tools"] = buildTools(tools) + } + + return result, nil +} + +// buildTools converts tool definitions to Anthropic format. +func buildTools(tools []ToolDefinition) []any { + result := make([]any, len(tools)) + for i, tool := range tools { + toolDef := map[string]any{ + "name": tool.Function.Name, + "description": tool.Function.Description, + "input_schema": tool.Function.Parameters, + } + result[i] = toolDef + } + return result +} + +// parseResponseBody parses Anthropic Messages API response. +func parseResponseBody(body []byte) (*LLMResponse, error) { + var resp anthropicMessageResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parsing JSON response: %w", err) + } + + // Extract content and tool calls + var content strings.Builder + toolCalls := make([]ToolCall, 0) // Initialize as empty slice (not nil) for consistent JSON serialization + + for _, block := range resp.Content { + switch block.Type { + case "text": + content.WriteString(block.Text) + case "tool_use": + argsJSON, _ := json.Marshal(block.Input) + toolCalls = append(toolCalls, ToolCall{ + ID: block.ID, + Name: block.Name, + Arguments: block.Input, + Function: &FunctionCall{ + Name: block.Name, + Arguments: string(argsJSON), + }, + }) + } + } + + // Map stop_reason + finishReason := "stop" + switch resp.StopReason { + case "tool_use": + finishReason = "tool_calls" + case "max_tokens": + finishReason = "length" + case "end_turn": + finishReason = "stop" + case "stop_sequence": + finishReason = "stop" + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + }, nil +} + +// normalizeBaseURL ensures the base URL is properly formatted. +// It removes /v1 suffix if present (to avoid duplication) and always appends /v1. +// This handles edge cases like "https://api.example.com/v1/proxy" correctly. +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + // Remove trailing slashes + base = strings.TrimRight(base, "/") + + // Remove /v1 suffix if present (will be re-added) + // This prevents duplication for URLs like "https://api.example.com/v1/proxy" + if before, ok := strings.CutSuffix(base, "/v1"); ok { + base = before + } + + // Ensure we don't have an empty string after cutting + if base == "" { + return defaultBaseURL + } + + // Add /v1 suffix (required by Anthropic Messages API) + return base + "/v1" +} + +// Helper functions for type conversion + +func asInt(v any) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case float64: + return int(val), true + case int64: + return int(val), true + default: + return 0, false + } +} + +func asFloat(v any) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} + +// Anthropic API response structures + +type anthropicMessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []contentBlock `json:"content"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Usage usageInfo `json:"usage"` +} + +type contentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input map[string]any `json:"input,omitempty"` +} + +type usageInfo struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` +} diff --git a/pkg/providers/anthropic_messages/provider_test.go b/pkg/providers/anthropic_messages/provider_test.go new file mode 100644 index 0000000000..da4213e92a --- /dev/null +++ b/pkg/providers/anthropic_messages/provider_test.go @@ -0,0 +1,622 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package anthropicmessages + +import ( + "context" + "encoding/json" + "reflect" + "strings" + "testing" +) + +func TestBuildRequestBody(t *testing.T) { + tests := []struct { + name string + messages []Message + tools []ToolDefinition + model string + options map[string]any + want map[string]any + wantErr bool + }{ + { + name: "basic user message", + messages: []Message{ + {Role: "user", Content: "Hello, world!"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + want: map[string]any{ + "model": "test-model", + "max_tokens": int64(8192), + "messages": []any{ + map[string]any{ + "role": "user", + "content": "Hello, world!", + }, + }, + }, + }, + { + name: "user and assistant messages", + messages: []Message{ + {Role: "user", Content: "What is 2+2?"}, + {Role: "assistant", Content: "4"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + want: map[string]any{ + "model": "test-model", + "max_tokens": int64(8192), + "messages": []any{ + map[string]any{ + "role": "user", + "content": "What is 2+2?", + }, + map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "text", + "text": "4", + }, + }, + }, + }, + }, + }, + { + name: "with system message", + messages: []Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + want: map[string]any{ + "model": "test-model", + "max_tokens": int64(8192), + "system": "You are a helpful assistant.", + "messages": []any{ + map[string]any{ + "role": "user", + "content": "Hello", + }, + }, + }, + }, + { + name: "with custom max_tokens and temperature", + messages: []Message{ + {Role: "user", Content: "Test"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 2048, + "temperature": 0.5, + }, + want: map[string]any{ + "model": "test-model", + "max_tokens": int64(2048), + "temperature": 0.5, + "messages": []any{ + map[string]any{ + "role": "user", + "content": "Test", + }, + }, + }, + }, + { + name: "missing max_tokens returns error", + messages: []Message{ + {Role: "user", Content: "Test"}, + }, + model: "test-model", + options: map[string]any{}, + want: nil, + wantErr: true, + }, + { + name: "with tools", + messages: []Message{ + {Role: "user", Content: "What's the weather?"}, + }, + tools: []ToolDefinition{ + { + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + want: map[string]any{ + "model": "test-model", + "max_tokens": int64(8192), + "messages": []any{ + map[string]any{ + "role": "user", + "content": "What's the weather?", + }, + }, + "tools": []any{ + map[string]any{ + "name": "get_weather", + "description": "Get current weather", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := buildRequestBody(tt.messages, tt.tools, tt.model, tt.options) + if (err != nil) != tt.wantErr { + t.Errorf("buildRequestBody() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + gotJSON, _ := json.MarshalIndent(got, "", " ") + wantJSON, _ := json.MarshalIndent(tt.want, "", " ") + t.Errorf("buildRequestBody() mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, wantJSON) + } + }) + } +} + +func TestParseResponseBody(t *testing.T) { + tests := []struct { + name string + body []byte + want *LLMResponse + wantErr bool + }{ + { + name: "basic text response", + body: []byte(`{ + "id": "msg-123", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello, how can I help?"} + ], + "stop_reason": "end_turn", + "model": "test-model", + "usage": { + "input_tokens": 10, + "output_tokens": 5 + } + }`), + want: &LLMResponse{ + Content: "Hello, how can I help?", + ToolCalls: []ToolCall{}, + FinishReason: "stop", + Usage: &UsageInfo{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + Reasoning: "", + ReasoningDetails: nil, + }, + wantErr: false, + }, + { + name: "response with tool use", + body: []byte(`{ + "id": "msg-456", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll check the weather for you."}, + { + "type": "tool_use", + "id": "toolu-123", + "name": "get_weather", + "input": {"location": "Tokyo"} + } + ], + "stop_reason": "tool_use", + "model": "test-model", + "usage": { + "input_tokens": 20, + "output_tokens": 15 + } + }`), + want: &LLMResponse{ + Content: "I'll check the weather for you.", + ToolCalls: []ToolCall{ + { + ID: "toolu-123", + Name: "get_weather", + Arguments: map[string]any{ + "location": "Tokyo", + }, + Function: &FunctionCall{ + Name: "get_weather", + Arguments: `{"location":"Tokyo"}`, + }, + }, + }, + FinishReason: "tool_calls", + Usage: &UsageInfo{ + PromptTokens: 20, + CompletionTokens: 15, + TotalTokens: 35, + }, + Reasoning: "", + ReasoningDetails: nil, + }, + wantErr: false, + }, + { + name: "invalid JSON", + body: []byte(`invalid json`), + want: nil, + wantErr: true, + }, + { + name: "max_tokens stop reason", + body: []byte(`{ + "id": "msg-789", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "Partial response"} + ], + "stop_reason": "max_tokens", + "model": "test-model", + "usage": { + "input_tokens": 100, + "output_tokens": 4096 + } + }`), + want: &LLMResponse{ + Content: "Partial response", + ToolCalls: []ToolCall{}, + FinishReason: "length", + Usage: &UsageInfo{ + PromptTokens: 100, + CompletionTokens: 4096, + TotalTokens: 4196, + }, + Reasoning: "", + ReasoningDetails: nil, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseResponseBody(tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("parseResponseBody() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + return + } + + // Compare individual fields + if got.Content != tt.want.Content { + t.Errorf("Content = %q, want %q", got.Content, tt.want.Content) + } + if got.FinishReason != tt.want.FinishReason { + t.Errorf("FinishReason = %q, want %q", got.FinishReason, tt.want.FinishReason) + } + if got.Usage == nil && tt.want.Usage != nil { + t.Errorf("Usage = nil, want non-nil") + } else if got.Usage != nil && tt.want.Usage == nil { + t.Errorf("Usage = non-nil, want nil") + } else if got.Usage != nil && tt.want.Usage != nil { + if got.Usage.PromptTokens != tt.want.Usage.PromptTokens { + t.Errorf("Usage.PromptTokens = %d, want %d", got.Usage.PromptTokens, tt.want.Usage.PromptTokens) + } + if got.Usage.CompletionTokens != tt.want.Usage.CompletionTokens { + t.Errorf("Usage.CompletionTokens = %d, want %d", + got.Usage.CompletionTokens, tt.want.Usage.CompletionTokens) + } + if got.Usage.TotalTokens != tt.want.Usage.TotalTokens { + t.Errorf("Usage.TotalTokens = %d, want %d", got.Usage.TotalTokens, tt.want.Usage.TotalTokens) + } + } + if len(got.ToolCalls) != len(tt.want.ToolCalls) { + t.Errorf("ToolCalls length = %d, want %d", len(got.ToolCalls), len(tt.want.ToolCalls)) + } else { + for i := range got.ToolCalls { + if got.ToolCalls[i].ID != tt.want.ToolCalls[i].ID { + t.Errorf("ToolCalls[%d].ID = %q, want %q", + i, got.ToolCalls[i].ID, tt.want.ToolCalls[i].ID) + } + if got.ToolCalls[i].Name != tt.want.ToolCalls[i].Name { + t.Errorf("ToolCalls[%d].Name = %q, want %q", + i, got.ToolCalls[i].Name, tt.want.ToolCalls[i].Name) + } + } + } + }) + } +} + +func TestNormalizeBaseURL(t *testing.T) { + tests := []struct { + name string + apiBase string + expected string + }{ + { + name: "empty string defaults to official API", + apiBase: "", + expected: "https://api.anthropic.com/v1", + }, + { + name: "URL without /v1 gets it appended", + apiBase: "https://api.example.com/anthropic", + expected: "https://api.example.com/anthropic/v1", + }, + { + name: "URL with /v1 remains unchanged", + apiBase: "https://api.example.com/v1", + expected: "https://api.example.com/v1", + }, + { + name: "URL with trailing slash gets cleaned", + apiBase: "https://api.example.com/anthropic/", + expected: "https://api.example.com/anthropic/v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeBaseURL(tt.apiBase) + if got != tt.expected { + t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected) + } + }) + } +} + +func TestNewProvider(t *testing.T) { + provider := NewProvider("test-key", "https://api.example.com") + if provider == nil { + t.Fatal("NewProvider() returned nil") + } + if provider.apiKey != "test-key" { + t.Errorf("provider.apiKey = %q, want %q", provider.apiKey, "test-key") + } + if provider.apiBase != "https://api.example.com/v1" { + t.Errorf("provider.apiBase = %q, want %q", provider.apiBase, "https://api.example.com/v1") + } +} + +func TestGetDefaultModel(t *testing.T) { + provider := NewProvider("test-key", "") + got := provider.GetDefaultModel() + expected := "claude-sonnet-4.6" + if got != expected { + t.Errorf("GetDefaultModel() = %q, want %q", got, expected) + } +} + +// TestBuildRequestBodyEdgeCases tests edge cases for buildRequestBody. +func TestBuildRequestBodyEdgeCases(t *testing.T) { + tests := []struct { + name string + messages []Message + tools []ToolDefinition + model string + options map[string]any + wantErr bool + }{ + { + name: "empty message list", + messages: []Message{}, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + wantErr: false, + }, + { + name: "very long system message", + messages: []Message{ + {Role: "system", Content: strings.Repeat("This is a very long system prompt. ", 1000)}, + {Role: "user", Content: "Hello"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + wantErr: false, + }, + { + name: "multiple consecutive system messages", + messages: []Message{ + {Role: "system", Content: "First system message"}, + {Role: "system", Content: "Second system message"}, + {Role: "system", Content: "Third system message"}, + {Role: "user", Content: "Hello"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + wantErr: false, + }, + { + name: "tool result without tool call", + messages: []Message{ + {Role: "user", Content: "Use a tool"}, + {Role: "assistant", Content: "", ToolCalls: []ToolCall{ + {ID: "tool-1", Name: "test_tool", Arguments: map[string]any{"arg": "value"}}, + }}, + {Role: "user", ToolCallID: "tool-1", Content: "Tool result"}, + }, + model: "test-model", + options: map[string]any{ + "max_tokens": 8192, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := buildRequestBody(tt.messages, tt.tools, tt.model, tt.options) + if (err != nil) != tt.wantErr { + t.Errorf("buildRequestBody() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + return + } + + // Verify basic structure + if got == nil { + t.Error("buildRequestBody() returned nil") + return + } + if got["model"] != tt.model { + t.Errorf("model = %v, want %v", got["model"], tt.model) + } + }) + } +} + +// TestParseResponseBodyEdgeCases tests edge cases for parseResponseBody. +func TestParseResponseBodyEdgeCases(t *testing.T) { + tests := []struct { + name string + body []byte + wantErr bool + check func(*testing.T, *LLMResponse) + }{ + { + name: "empty content blocks", + body: []byte(`{ + "id": "msg-empty", + "type": "message", + "role": "assistant", + "content": [], + "stop_reason": "end_turn", + "model": "test-model", + "usage": {"input_tokens": 5, "output_tokens": 0} + }`), + wantErr: false, + check: func(t *testing.T, resp *LLMResponse) { + if resp.Content != "" { + t.Errorf("Content = %q, want empty string", resp.Content) + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls length = %d, want 0", len(resp.ToolCalls)) + } + }, + }, + { + name: "multiple tool use blocks", + body: []byte(`{ + "id": "msg-multi", + "type": "message", + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "tool-1", "name": "func1", "input": {"arg": "val1"}}, + {"type": "tool_use", "id": "tool-2", "name": "func2", "input": {"arg": "val2"}} + ], + "stop_reason": "tool_use", + "model": "test-model", + "usage": {"input_tokens": 10, "output_tokens": 20} + }`), + wantErr: false, + check: func(t *testing.T, resp *LLMResponse) { + if len(resp.ToolCalls) != 2 { + t.Errorf("ToolCalls length = %d, want 2", len(resp.ToolCalls)) + } + }, + }, + { + name: "malformed JSON response", + body: []byte(`{invalid json`), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseResponseBody(tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("parseResponseBody() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.check != nil && err == nil { + tt.check(t, got) + } + }) + } +} + +// TestProviderChatErrors tests error handling in Chat. +// Note: apiBase check removed as it's dead code - normalizeBaseURL() always provides a default. +func TestProviderChatErrors(t *testing.T) { + tests := []struct { + name string + apiKey string + messages []Message + wantErrMsg string + }{ + { + name: "missing API key", + apiKey: "", + messages: []Message{{Role: "user", Content: "Test"}}, + wantErrMsg: "API key not configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create provider using constructor to ensure proper initialization + provider := NewProvider(tt.apiKey, "https://api.example.com") + + _, err := provider.Chat(context.Background(), tt.messages, nil, "test-model", nil) + if err == nil { + t.Fatal("Chat() expected error, got nil") + } + if err.Error() != tt.wantErrMsg { + t.Errorf("Chat() error = %q, want %q", err.Error(), tt.wantErrMsg) + } + }) + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index a798154cbb..96699f6d69 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/config" + anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" ) // createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. @@ -53,7 +54,8 @@ func ExtractProtocol(model string) (protocol, modelID string) { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. -// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot +// Supported protocols: openai, litellm, anthropic, anthropic-messages, antigravity, +// claude-cli, codex-cli, github-copilot // Returns the provider, the model ID (without protocol prefix), and any error. func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { if cfg == nil { @@ -137,6 +139,21 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil + case "anthropic-messages": + // Anthropic Messages API with native format (HTTP-based, no SDK) + apiBase := cfg.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + if cfg.APIKey == "" { + return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model) + } + return anthropicmessages.NewProviderWithTimeout( + cfg.APIKey, + apiBase, + cfg.RequestTimeout, + ), modelID, nil + case "antigravity": return NewAntigravityProvider(), modelID, nil