Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
415 changes: 415 additions & 0 deletions pkg/providers/anthropic_messages/provider.go

Large diffs are not rendered by default.

622 changes: 622 additions & 0 deletions pkg/providers/anthropic_messages/provider_test.go

Large diffs are not rendered by default.

150 changes: 150 additions & 0 deletions pkg/providers/azure/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package azure

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)

// Local aliases for the shared protocol types so the rest of this file can
// refer to them without the package qualifier.
type (
	LLMResponse    = protocoltypes.LLMResponse
	Message        = protocoltypes.Message
	ToolDefinition = protocoltypes.ToolDefinition
)

const (
	// azureAPIVersion is the Azure OpenAI API version used for all requests.
	azureAPIVersion = "2024-10-21"
	// defaultRequestTimeout mirrors the providers/common default so every
	// provider times out consistently unless overridden via an Option.
	defaultRequestTimeout = common.DefaultRequestTimeout
)

// Provider implements the LLM provider interface for Azure OpenAI endpoints.
// It handles Azure-specific authentication (api-key header), URL construction
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
type Provider struct {
	apiKey     string       // sent as the api-key header when non-empty
	apiBase    string       // endpoint base URL, stored without a trailing slash
	httpClient *http.Client // shared HTTP client; Timeout configurable via options
}

// Option configures the Azure Provider.
type Option func(*Provider)

// WithRequestTimeout sets the HTTP request timeout.
// Zero or negative durations are ignored, leaving the client's current
// timeout in place.
func WithRequestTimeout(timeout time.Duration) Option {
	return func(p *Provider) {
		if timeout <= 0 {
			return
		}
		p.httpClient.Timeout = timeout
	}
}

// NewProvider creates a new Azure OpenAI provider.
// Trailing slashes are stripped from apiBase so later URL joining stays
// predictable; proxy is forwarded to the shared HTTP client constructor.
// Nil options are skipped.
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
	provider := &Provider{
		apiKey:     apiKey,
		apiBase:    strings.TrimRight(apiBase, "/"),
		httpClient: common.NewHTTPClient(proxy),
	}

	for i := range opts {
		if apply := opts[i]; apply != nil {
			apply(provider)
		}
	}

	return provider
}

// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom
// request timeout given in seconds.
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
	timeout := time.Duration(requestTimeoutSeconds) * time.Second
	return NewProvider(apiKey, apiBase, proxy, WithRequestTimeout(timeout))
}

// Chat sends a chat completion request to the Azure OpenAI endpoint.
//
// The model parameter is used as the Azure deployment name in the URL; the
// request body deliberately carries no "model" field because Azure infers the
// model from the deployment. Recognized options are "max_tokens" (sent as
// max_completion_tokens, which Azure requires) and "temperature". Non-200
// responses are converted to errors via the shared common helpers.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("Azure API base not configured")
	}

	// model is the deployment name for Azure OpenAI. Percent-escape it so a
	// hostile name cannot rewrite the URL: url.JoinPath cleans "../" elements
	// but does NOT escape "/" inside an element, so an unescaped name such as
	// "x/../../admin" would otherwise collapse into a different endpoint path.
	deployment := url.PathEscape(model)

	base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
	if err != nil {
		return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
	}
	// azureAPIVersion is a trusted constant, so plain concatenation is safe.
	requestURL := base + "?api-version=" + azureAPIVersion

	// Build request body — no "model" field (Azure infers from deployment URL).
	requestBody := map[string]any{
		"messages": common.SerializeMessages(messages),
	}

	if len(tools) > 0 {
		requestBody["tools"] = tools
		requestBody["tool_choice"] = "auto"
	}

	// Azure OpenAI always uses max_completion_tokens instead of max_tokens.
	if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
		requestBody["max_completion_tokens"] = maxTokens
	}

	if temperature, ok := common.AsFloat(options["temperature"]); ok {
		requestBody["temperature"] = temperature
	}

	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Azure authenticates with an api-key header instead of Authorization: Bearer.
	req.Header.Set("Content-Type", "application/json")
	if p.apiKey != "" {
		req.Header.Set("Api-Key", p.apiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, common.HandleErrorResponse(resp, p.apiBase)
	}

	return common.ReadAndParseResponse(resp, p.apiBase)
}

// GetDefaultModel returns an empty string as Azure deployments are user-configured.
// Callers must always supply the deployment name via Chat's model argument.
func (p *Provider) GetDefaultModel() string {
	return ""
}
232 changes: 232 additions & 0 deletions pkg/providers/azure/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package azure

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)

// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
func writeValidResponse(w http.ResponseWriter) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

// TestProviderChat_AzureURLConstruction checks that the deployment name and
// api-version land in the expected places in the request URL.
func TestProviderChat_AzureURLConstruction(t *testing.T) {
	type captured struct {
		path       string
		apiVersion string
	}
	var got captured

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got = captured{path: r.URL.Path, apiVersion: r.URL.Query().Get("api-version")}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	if _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	const wantPath = "/openai/deployments/my-gpt5-deployment/chat/completions"
	if got.path != wantPath {
		t.Errorf("URL path = %q, want %q", got.path, wantPath)
	}
	if got.apiVersion != azureAPIVersion {
		t.Errorf("api-version = %q, want %q", got.apiVersion, azureAPIVersion)
	}
}

// TestProviderChat_AzureAuthHeader checks that authentication goes through the
// Azure-style api-key header and never Authorization: Bearer.
func TestProviderChat_AzureAuthHeader(t *testing.T) {
	var captured http.Header

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured = r.Header.Clone()
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-azure-key", server.URL, "")
	if _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if got := captured.Get("Api-Key"); got != "test-azure-key" {
		t.Errorf("api-key header = %q, want %q", got, "test-azure-key")
	}
	if got := captured.Get("Authorization"); got != "" {
		t.Errorf("Authorization header should be empty, got %q", got)
	}
}

// TestProviderChat_AzureOmitsModelFromBody verifies that the request body
// carries no "model" field — Azure infers the model from the deployment URL.
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
	var requestBody map[string]any

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report decode failures instead of silently ignoring them, which
		// would leave requestBody empty and make the assertion below
		// vacuously pass. t.Errorf is safe from a handler goroutine.
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			t.Errorf("decoding request body: %v", err)
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if _, exists := requestBody["model"]; exists {
		t.Error("request body should not contain 'model' field for Azure OpenAI")
	}
}

// TestProviderChat_AzureUsesMaxCompletionTokens verifies that the generic
// "max_tokens" option is translated into Azure's "max_completion_tokens"
// request field and the original key is dropped.
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
	var requestBody map[string]any

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report decode failures instead of silently ignoring them, which
		// would leave requestBody empty and make both assertions below
		// vacuously pass. t.Errorf is safe from a handler goroutine.
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			t.Errorf("decoding request body: %v", err)
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"deployment",
		map[string]any{"max_tokens": 2048},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if _, exists := requestBody["max_completion_tokens"]; !exists {
		t.Error("request body should contain 'max_completion_tokens'")
	}
	if _, exists := requestBody["max_tokens"]; exists {
		t.Error("request body should not contain 'max_tokens'")
	}
}

// TestProviderChat_AzureHTTPError checks that a non-200 status surfaces as an
// error from Chat.
func TestProviderChat_AzureHTTPError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
	}))
	defer server.Close()

	p := NewProvider("bad-key", server.URL, "")
	if _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil); err == nil {
		t.Fatal("expected error, got nil")
	}
}

// TestProviderChat_AzureParseToolCalls checks that tool_calls in the response
// body are parsed into the provider's ToolCalls result.
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
	toolCall := map[string]any{
		"id":   "call_1",
		"type": "function",
		"function": map[string]any{
			"name":      "get_weather",
			"arguments": `{"city":"Seattle"}`,
		},
	}
	message := map[string]any{
		"content":    "",
		"tool_calls": []map[string]any{toolCall},
	}
	payload := map[string]any{
		"choices": []map[string]any{
			{"message": message, "finish_reason": "tool_calls"},
		},
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")
	out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if len(out.ToolCalls) != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
	}
	if got := out.ToolCalls[0].Name; got != "get_weather" {
		t.Errorf("ToolCalls[0].Name = %q, want %q", got, "get_weather")
	}
}

// TestProvider_AzureEmptyAPIBase checks that Chat refuses to run without a
// configured API base.
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
	p := NewProvider("test-key", "", "")
	if _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil); err == nil {
		t.Fatal("expected error for empty API base")
	}
}

// TestProvider_AzureRequestTimeoutDefault checks the client keeps the shared
// default timeout when no option overrides it.
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
	p := NewProvider("test-key", "https://example.com", "")
	if got := p.httpClient.Timeout; got != defaultRequestTimeout {
		t.Errorf("timeout = %v, want %v", got, defaultRequestTimeout)
	}
}

// TestProvider_AzureRequestTimeoutOverride checks WithRequestTimeout replaces
// the default client timeout.
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
	const want = 300 * time.Second
	p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(want))
	if got := p.httpClient.Timeout; got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}

// TestProvider_AzureNewProviderWithTimeout checks the seconds-based
// constructor converts its argument into the client timeout.
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
	const want = 180 * time.Second
	p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
	if got := p.httpClient.Timeout; got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}

// TestProviderChat_AzureDeploymentNameEscaped checks that a deployment name
// containing path metacharacters is not interpolated verbatim into the URL.
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
	var capturedPath string

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RawPath preserves percent-encoding; it is empty when Path needed none.
		if capturedPath = r.URL.RawPath; capturedPath == "" {
			capturedPath = r.URL.Path
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	p := NewProvider("test-key", server.URL, "")

	// Deployment name with characters that could cause path injection.
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// NOTE(review): this is a negative assertion only — it proves the raw name
	// was not interpolated verbatim, not that the resulting path stayed under
	// /openai/deployments/. Consider asserting the escaped form directly.
	const injected = "/openai/deployments/my deploy/../../admin/chat/completions"
	if capturedPath == injected {
		t.Fatal("deployment name was interpolated without escaping — path injection possible")
	}
}
Loading
Loading