Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
Expand Down Expand Up @@ -323,12 +323,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {

if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
}

// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
Expand Down Expand Up @@ -361,6 +356,39 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
}, nil
}

func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return arguments
}

var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
arguments["raw"] = string(raw)
return arguments
}

switch v := decoded.(type) {
case string:
if strings.TrimSpace(v) == "" {
return arguments
}
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = v
}
return arguments
case map[string]any:
return v
default:
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
arguments["raw"] = string(raw)
return arguments
}
}

// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
Expand Down
49 changes: 49 additions & 0 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,55 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
}
}

func TestProviderChat_ParsesToolCallsWithObjectArguments(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": map[string]any{
"city": "SF",
"metric": true,
},
},
},
},
},
"finish_reason": "tool_calls",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()

p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
if out.ToolCalls[0].Arguments["metric"] != true {
t.Fatalf("ToolCalls[0].Arguments[metric] = %v, want true", out.ToolCalls[0].Arguments["metric"])
}
}

func TestProviderChat_ParsesReasoningContent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
Expand Down
Loading