Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log"
"net/http"
"net/url"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -40,6 +41,18 @@ type Option func(*Provider)

const defaultRequestTimeout = 120 * time.Second

// escapedTagReplacer rewrites the literal escape sequences \u003c and \u003e
// (with either hex-digit case) into the "<" and ">" characters they stand for,
// so that escaped tags can be matched by the patterns below.
var escapedTagReplacer = strings.NewReplacer(
	`\u003c`, "<",
	`\u003C`, "<",
	`\u003e`, ">",
	`\u003E`, ">",
)

// reasoningTagPattern matches a complete reasoning block: an opening
// think/thinking/thought/reasoning tag (attributes allowed) through the
// nearest closing tag carrying any of those names. Matching is
// case-insensitive and spans newlines. The closing tag name is not required
// to equal the opening one.
var reasoningTagPattern = regexp.MustCompile(`(?is)<(?:think|thinking|thought|reasoning)\b[^>]*>.*?</(?:think|thinking|thought|reasoning)\s*>`)

// trailingReasoningTagPattern matches an opening reasoning tag together with
// everything that follows it up to the end of the input. It is meant for
// blocks that were never closed.
var trailingReasoningTagPattern = regexp.MustCompile(`(?is)<(?:think|thinking|thought|reasoning)\b[^>]*>.*$`)

// finalTagPattern matches opening and closing "final" tags themselves,
// leaving the text between them untouched.
var finalTagPattern = regexp.MustCompile(`(?is)</?final\b[^>]*>`)

func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
p.maxTokensField = maxTokensField
Expand Down Expand Up @@ -351,7 +364,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
}

return &LLMResponse{
Content: choice.Message.Content,
Content: sanitizeAssistantContent(choice.Message.Content),
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
Expand All @@ -361,6 +374,19 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
}, nil
}

// sanitizeAssistantContent removes reasoning markup from assistant output.
//
// The steps run in a fixed order:
//  1. escaped angle brackets are converted back via escapedTagReplacer;
//  2. complete reasoning blocks are deleted (reasoningTagPattern);
//  3. an unclosed reasoning block running to the end is deleted
//     (trailingReasoningTagPattern);
//  4. <final> wrapper tags are stripped while their inner text is kept
//     (finalTagPattern).
//
// The result is returned with surrounding whitespace trimmed. An empty input
// yields an empty string.
func sanitizeAssistantContent(content string) string {
	if len(content) == 0 {
		return content
	}

	unescaped := escapedTagReplacer.Replace(content)
	withoutClosedBlocks := reasoningTagPattern.ReplaceAllString(unescaped, "")
	// Must follow the closed-block pass: otherwise the first opening tag
	// would swallow everything after it, including a legitimate answer.
	withoutDanglingBlock := trailingReasoningTagPattern.ReplaceAllString(withoutClosedBlocks, "")
	withoutFinalTags := finalTagPattern.ReplaceAllString(withoutDanglingBlock, "")

	return strings.TrimSpace(withoutFinalTags)
}

// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
Expand Down
57 changes: 57 additions & 0 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,63 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) {
}
}

// TestParseResponse_StripsThinkingAndFinalTags feeds parseResponse a message
// whose content holds a <think> block followed by a <final>-wrapped answer and
// expects Content to contain only the answer text.
func TestParseResponse_StripsThinkingAndFinalTags(t *testing.T) {
	const want = "The answer is 2"
	const payload = `{
"choices": [{
"message": {
"content": "<think>internal reasoning</think><final>The answer is 2</final>"
},
"finish_reason": "stop"
}]
}`

	resp, err := parseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("parseResponse() error = %v", err)
	}
	if got := resp.Content; got != want {
		t.Fatalf("Content = %q, want %q", got, want)
	}
}

// TestParseResponse_StripsEscapedThinkingAndFinalTags covers content in which
// the tag brackets arrive as literal \u003c / \u003e text (the JSON carries a
// doubled backslash, so the decoder leaves the escape sequence as plain
// characters). Content is still expected to be just the answer text.
func TestParseResponse_StripsEscapedThinkingAndFinalTags(t *testing.T) {
	const want = "The answer is 2"
	const payload = `{
"choices": [{
"message": {
"content": "\\u003cthink\\u003einternal reasoning\\u003c/think\\u003e\\u003cfinal\\u003eThe answer is 2\\u003c/final\\u003e"
},
"finish_reason": "stop"
}]
}`

	resp, err := parseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("parseResponse() error = %v", err)
	}
	if got := resp.Content; got != want {
		t.Fatalf("Content = %q, want %q", got, want)
	}
}

// TestParseResponse_DropsTrailingUnclosedThinkingBlock covers content where a
// <final>-wrapped answer is followed by a <think> tag that is never closed.
// Everything from the dangling tag onward is expected to be dropped, leaving
// only the answer text in Content.
func TestParseResponse_DropsTrailingUnclosedThinkingBlock(t *testing.T) {
	const want = "The answer is 2"
	const payload = `{
"choices": [{
"message": {
"content": "<final>The answer is 2</final><think>internal reasoning"
},
"finish_reason": "stop"
}]
}`

	resp, err := parseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("parseResponse() error = %v", err)
	}
	if got := resp.Content; got != want {
		t.Fatalf("Content = %q, want %q", got, want)
	}
}

func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) {
var requestBody map[string]any

Expand Down