diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index f97bf3acd5..6405b6ced7 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -164,6 +164,10 @@ func (p *Provider) Chat(
 		}
 	}
+	if strings.Contains(p.apiBase, "api.minimaxi.com") {
+		requestBody["reasoning_split"] = true
+	}
+
 	jsonData, err := json.Marshal(requestBody)
 	if err != nil {
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
 	}
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 41f278a1b1..48648330fa 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -60,6 +60,71 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
 	}
 }
 
+func TestProviderChat_SkipsPromptCacheKeyForNvidia(t *testing.T) {
+	var requestBody map[string]any
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
+	}))
+	defer server.Close()
+
+	p := NewProvider("key", server.URL+"/integrate.api.nvidia.com/v1", "")
+	p.apiBase = server.URL + "/integrate.api.nvidia.com/v1"
+
+	_, err := p.Chat(
+		t.Context(),
+		[]Message{{Role: "user", Content: "hi"}},
+		nil,
+		"nvidia/llama-3.1",
+		map[string]any{"prompt_cache_key": "agent-main"},
+	)
+	if err != nil {
+		t.Fatalf("Chat() error = %v", err)
+	}
+	if _, ok := requestBody["prompt_cache_key"]; ok {
+		t.Fatalf("did not expect prompt_cache_key for nvidia endpoint")
+	}
+}
+
+func TestProviderChat_EnablesReasoningSplitForMiniMax(t *testing.T) {
+	var requestBody map[string]any
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok","reasoning_details":[{"type":"reasoning","text":"think"}]},"finish_reason":"stop"}]}`))
+	}))
+	defer server.Close()
+
+	p := NewProvider("key", server.URL+"/api.minimaxi.com/v1", "")
+	p.apiBase = server.URL + "/api.minimaxi.com/v1"
+
+	out, err := p.Chat(
+		t.Context(),
+		[]Message{{Role: "user", Content: "hi"}},
+		nil,
+		"minimax/MiniMax-M2.5",
+		nil,
+	)
+	if err != nil {
+		t.Fatalf("Chat() error = %v", err)
+	}
+	if requestBody["reasoning_split"] != true {
+		t.Fatalf("reasoning_split = %v, want true", requestBody["reasoning_split"])
+	}
+	if len(out.ReasoningDetails) != 1 {
+		t.Fatalf("len(ReasoningDetails) = %d, want 1", len(out.ReasoningDetails))
+	}
+}
+
 func TestProviderChat_ParsesToolCalls(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		resp := map[string]any{