Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
Expand Down Expand Up @@ -314,6 +314,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]any)
name := ""
rawArguments := ""

// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
Expand All @@ -323,22 +324,25 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {

if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
rawArguments, arguments = decodeToolCallArguments(tc.Function.Arguments)
}

// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
toolCall := ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}

if tc.Function != nil {
toolCall.Function = &FunctionCall{
Name: name,
Arguments: rawArguments,
}
}

if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
Expand All @@ -361,6 +365,43 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
}, nil
}

func decodeToolCallArguments(raw json.RawMessage) (string, map[string]any) {
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return "", map[string]any{}
}

if raw[0] == '"' {
var encoded string
if err := json.Unmarshal(raw, &encoded); err != nil {
log.Printf("openai_compat: failed to decode tool call argument string: %v", err)
return string(raw), map[string]any{"raw": string(raw)}
}

arguments := map[string]any{}
if encoded == "" {
return "", arguments
}
if err := json.Unmarshal([]byte(encoded), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments JSON: %v", err)
arguments["raw"] = encoded
}
return encoded, arguments
}

var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments payload: %v", err)
return string(raw), map[string]any{"raw": string(raw)}
}

if object, ok := decoded.(map[string]any); ok {
return string(raw), object
}

return string(raw), map[string]any{"raw": decoded}
}

// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
Expand Down
52 changes: 52 additions & 0 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,58 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
if out.ToolCalls[0].Function == nil {
t.Fatal("ToolCalls[0].Function = nil, want populated function call")
}
if out.ToolCalls[0].Function.Arguments != "{\"city\":\"SF\"}" {
t.Fatalf("ToolCalls[0].Function.Arguments = %q, want JSON string", out.ToolCalls[0].Function.Arguments)
}
}

func TestProviderChat_ParsesToolCallsWithObjectArguments(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": map[string]any{"city": "SF"},
},
},
},
},
"finish_reason": "tool_calls",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()

p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
if out.ToolCalls[0].Function == nil {
t.Fatal("ToolCalls[0].Function = nil, want populated function call")
}
if out.ToolCalls[0].Function.Arguments != "{\"city\":\"SF\"}" {
t.Fatalf("ToolCalls[0].Function.Arguments = %q, want JSON object string", out.ToolCalls[0].Function.Arguments)
}
}

func TestProviderChat_ParsesReasoningContent(t *testing.T) {
Expand Down