From e4956f6b4f95be24087d93111d62da0af809735b Mon Sep 17 00:00:00 2001 From: XYSK-lilong007 <267018309+XYSK-lilong007@users.noreply.github.com> Date: Thu, 12 Mar 2026 01:24:27 +0800 Subject: [PATCH] fix: accept object tool arguments in openai compat --- pkg/providers/openai_compat/provider.go | 57 +++++++++++++++++--- pkg/providers/openai_compat/provider_test.go | 52 ++++++++++++++++++ 2 files changed, 101 insertions(+), 8 deletions(-) diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 0e8db74097..8feaea7ac9 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -283,8 +283,8 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { ID string `json:"id"` Type string `json:"type"` Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` } `json:"function"` ExtraContent *struct { Google *struct { @@ -314,6 +314,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { for _, tc := range choice.Message.ToolCalls { arguments := make(map[string]any) name := "" + rawArguments := "" // Extract thought_signature from Gemini/Google-specific extra content thoughtSignature := "" @@ -323,22 +324,25 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { if tc.Function != nil { name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) - arguments["raw"] = tc.Function.Arguments - } - } + rawArguments, arguments = decodeToolCallArguments(tc.Function.Arguments) } // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence toolCall := ToolCall{ ID: tc.ID, + Type: tc.Type, Name: name, Arguments: arguments, ThoughtSignature: thoughtSignature, } + if tc.Function != nil { + toolCall.Function = &FunctionCall{ + Name: name, + Arguments: rawArguments, + } + } + if thoughtSignature != "" { toolCall.ExtraContent = &ExtraContent{ Google: &GoogleExtra{ @@ -361,6 +365,43 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { }, nil } +func decodeToolCallArguments(raw json.RawMessage) (string, map[string]any) { + raw = bytes.TrimSpace(raw) + if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { + return "", map[string]any{} + } + + if raw[0] == '"' { + var encoded string + if err := json.Unmarshal(raw, &encoded); err != nil { + log.Printf("openai_compat: failed to decode tool call argument string: %v", err) + return string(raw), map[string]any{"raw": string(raw)} + } + + arguments := map[string]any{} + if encoded == "" { + return "", arguments + } + if err := json.Unmarshal([]byte(encoded), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments JSON: %v", err) + arguments["raw"] = encoded + } + return encoded, arguments + } + + var decoded any + if err := json.Unmarshal(raw, &decoded); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments payload: %v", err) + return string(raw), map[string]any{"raw": string(raw)} + } + + if object, ok := decoded.(map[string]any); ok { + return string(raw), object + } + + return string(raw), map[string]any{"raw": decoded} +} + // openaiMessage is the wire-format message for OpenAI-compatible APIs. // It mirrors protocoltypes.Message but omits SystemParts, which is an // internal field that would be unknown to third-party endpoints. diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 9a3a7acc5c..84a408eaab 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -106,6 +106,58 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) { if out.ToolCalls[0].Arguments["city"] != "SF" { t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) } + if out.ToolCalls[0].Function == nil { + t.Fatal("ToolCalls[0].Function = nil, want populated function call") + } + if out.ToolCalls[0].Function.Arguments != "{\"city\":\"SF\"}" { + t.Fatalf("ToolCalls[0].Function.Arguments = %q, want JSON string", out.ToolCalls[0].Function.Arguments) + } +} + +func TestProviderChat_ParsesToolCallsWithObjectArguments(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": map[string]any{"city": "SF"}, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } + if out.ToolCalls[0].Function == nil { + t.Fatal("ToolCalls[0].Function = nil, want populated function call") + } + if out.ToolCalls[0].Function.Arguments != "{\"city\":\"SF\"}" { + t.Fatalf("ToolCalls[0].Function.Arguments = %q, want JSON object string", out.ToolCalls[0].Function.Arguments) + } } func TestProviderChat_ParsesReasoningContent(t *testing.T) {