Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 84 additions & 8 deletions pkg/providers/openai_compat/provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai_compat

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand Down Expand Up @@ -183,19 +184,94 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
contentType := resp.Header.Get("Content-Type")

// Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
responsePreview(body, 128),
)
}

// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}

out, err := parseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
return out, nil
}

// wrapHTMLResponseError builds the error reported when the endpoint answered
// with an HTML page instead of a JSON payload.
//
// The message names apiBase and contentType, hints at checking the api_base or
// proxy configuration, and includes the HTTP statusCode together with a
// preview of body limited to 128 bytes (see responsePreview).
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
	const format = "API request failed: %s returned HTML instead of JSON (content-type: %s); " +
		"check api_base or proxy configuration.\n Status: %d\n Body: %s"

	return fmt.Errorf(format, apiBase, contentType, statusCode, responsePreview(body, 128))
}

// looksLikeHTML reports whether a response appears to be an HTML document.
//
// It returns true when contentType (compared case-insensitively) mentions
// "text/html" or "application/xhtml+xml". Otherwise it inspects body: after
// skipping leading whitespace, at most 128 bytes are lower-cased and checked
// for one of the opening markers "<!doctype html", "<html", "<head" or
// "<body". This second check catches HTML that was served with a misleading
// content type.
func looksLikeHTML(body []byte, contentType string) bool {
	mediaType := strings.ToLower(strings.TrimSpace(contentType))
	for _, htmlType := range []string{"text/html", "application/xhtml+xml"} {
		if strings.Contains(mediaType, htmlType) {
			return true
		}
	}

	head := bytes.ToLower(leadingTrimmedPrefix(body, 128))
	for _, marker := range []string{"<!doctype html", "<html", "<head", "<body"} {
		if bytes.HasPrefix(head, []byte(marker)) {
			return true
		}
	}
	return false
}

// leadingTrimmedPrefix skips leading ASCII whitespace (space, \t, \n, \v, \f,
// \r) in body and returns up to maxLen bytes starting at the first
// non-whitespace byte. The result is a sub-slice of body, not a copy.
//
// It returns nil when body is empty or consists solely of whitespace.
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
	for start := range body {
		// '\t' (0x09) through '\r' (0x0D) is exactly \t, \n, \v, \f, \r.
		if c := body[start]; c == ' ' || (c >= '\t' && c <= '\r') {
			continue
		}

		stop := len(body)
		if limit := start + maxLen; limit < stop {
			stop = limit
		}
		return body[start:stop]
	}
	return nil
}

return parseResponse(body)
// responsePreview returns a short excerpt of body suitable for embedding in an
// error message.
//
// Leading and trailing whitespace is stripped first. An empty or
// all-whitespace body yields "<empty>". A body of at most maxLen bytes is
// returned unchanged; a longer one is cut to at most maxLen bytes and suffixed
// with "...". When the cut would land inside a multi-byte UTF-8 sequence it is
// moved back to the start of that sequence, so the preview never ends in a
// torn character. A negative maxLen is treated as 0.
func responsePreview(body []byte, maxLen int) string {
	trimmed := bytes.TrimSpace(body)
	if len(trimmed) == 0 {
		return "<empty>"
	}
	if len(trimmed) <= maxLen {
		return string(trimmed)
	}

	cut := maxLen
	if cut < 0 {
		// Slicing with a negative bound would panic; show only the ellipsis.
		cut = 0
	}
	// trimmed[cut] is the first byte that gets dropped. If it is a UTF-8
	// continuation byte (bit pattern 10xxxxxx) the cut sits in the middle of a
	// multi-byte sequence, so step back to that sequence's lead byte. A
	// sequence is at most 4 bytes long, hence at most 3 steps; the bound also
	// keeps non-UTF-8 payloads from shrinking the preview arbitrarily.
	// (Equivalent to !utf8.RuneStart, written out to avoid a new import.)
	for steps := 0; steps < 3 && cut > 0 && trimmed[cut]&0xC0 == 0x80; steps++ {
		cut--
	}
	return string(trimmed[:cut]) + "..."
}

func parseResponse(body []byte) (*LLMResponse, error) {
func parseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Expand All @@ -222,8 +298,8 @@ func parseResponse(body []byte) (*LLMResponse, error) {
Usage *UsageInfo `json:"usage"`
}

if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}

if len(apiResponse.Choices) == 0 {
Expand Down
163 changes: 163 additions & 0 deletions pkg/providers/openai_compat/provider_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package openai_compat

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -212,6 +215,132 @@ func TestProviderChat_HTTPError(t *testing.T) {
}
}

// TestProviderChat_JSONHTTPErrorDoesNotReportHTML checks that a 400 response
// with a JSON content type and JSON body produces an error that mentions the
// status code but does not claim the server returned HTML.
func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte(`{"error":"bad request"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("key", srv.URL, "")
	_, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	msg := err.Error()
	switch {
	case !strings.Contains(msg, "Status: 400"):
		t.Fatalf("expected status code in error, got %v", err)
	case strings.Contains(msg, "returned HTML instead of JSON"):
		t.Fatalf("expected non-HTML http error, got %v", err)
	}
}

// TestProviderChat_HTMLResponsesReturnHelpfulError is a table-driven test
// covering HTML bodies served with a 200 status, with an error status, and
// with a JSON content type (preceded by whitespace). In every case Chat must
// fail with an error that contains the status code, states that HTML was
// returned instead of JSON, and carries the configuration hint.
func TestProviderChat_HTMLResponsesReturnHelpfulError(t *testing.T) {
	type htmlCase struct {
		name        string
		contentType string
		statusCode  int
		body        string
	}

	cases := []htmlCase{
		{
			name:        "html success response",
			contentType: "text/html; charset=utf-8",
			statusCode:  http.StatusOK,
			body:        "<!DOCTYPE html><html><body>gateway login</body></html>",
		},
		{
			name:        "html error response",
			contentType: "text/html; charset=utf-8",
			statusCode:  http.StatusBadGateway,
			body:        "<!DOCTYPE html><html><body>bad gateway</body></html>",
		},
		{
			name:        "mislabeled html success response",
			contentType: "application/json",
			statusCode:  http.StatusOK,
			body:        " \r\n\t<!DOCTYPE html><html><body>gateway login</body></html>",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler := func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", tc.contentType)
				w.WriteHeader(tc.statusCode)
				_, _ = w.Write([]byte(tc.body))
			}
			srv := httptest.NewServer(http.HandlerFunc(handler))
			defer srv.Close()

			provider := NewProvider("key", srv.URL, "")
			_, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
			if err == nil {
				t.Fatal("expected error, got nil")
			}

			// Checked in order; the first missing fragment fails the subtest.
			msg := err.Error()
			required := []struct {
				fragment string
				what     string
			}{
				{fmt.Sprintf("Status: %d", tc.statusCode), "expected status code in error"},
				{"returned HTML instead of JSON", "expected helpful HTML error"},
				{"check api_base or proxy configuration", "expected configuration hint"},
			}
			for _, req := range required {
				if !strings.Contains(msg, req.fragment) {
					t.Fatalf("%s, got %v", req.what, err)
				}
			}
		})
	}
}

func TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) {
content := strings.Repeat("a", 1024)
body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}`

p := NewProvider("key", "https://example.com/v1", "")
p.httpClient = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &errAfterDataReadCloser{
data: []byte(body),
chunkSize: 64,
},
}, nil
}),
}

out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if out.Content != content {
t.Fatalf("Content = %q, want %q", out.Content, content)
}
}

// TestProviderChat_LargeHTMLResponsePreviewIsTruncated serves a 502 HTML page
// whose body contains 2048 filler bytes and checks that the resulting error
// shows the beginning of the page after "Body: " and contains an ellipsis.
func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) {
	page := bytes.Join([][]byte{
		[]byte("<!DOCTYPE html><html><body>"),
		bytes.Repeat([]byte("A"), 2048),
		[]byte("</body></html>"),
	}, nil)

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write(page)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("key", srv.URL, "")
	_, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	msg := err.Error()
	switch {
	case !strings.Contains(msg, "Body: <!DOCTYPE html><html><body>"):
		t.Fatalf("expected html preview in error, got %v", err)
	case !strings.Contains(msg, "..."):
		t.Fatalf("expected truncated preview, got %v", err)
	}
}

func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
var requestBody map[string]any

Expand Down Expand Up @@ -399,6 +528,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
}
}

// roundTripperFunc adapts a plain function to the RoundTrip method set so a
// test can supply an http.Client transport inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the underlying function with req and returns its result
// unchanged.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}

// errAfterDataReadCloser is a test body that serves data in bounded chunks and
// then fails with io.ErrUnexpectedEOF rather than signalling a clean io.EOF.
type errAfterDataReadCloser struct {
	data      []byte // payload handed out by Read
	chunkSize int    // upper bound per Read; <= 0 means "as much as fits in p"
	offset    int    // number of bytes of data already delivered
}

// Read copies the next chunk of data into p. Each call delivers at most
// chunkSize bytes (or len(p) when chunkSize is non-positive or larger than p),
// never more than what is left. Once all data has been delivered every call
// returns 0 and io.ErrUnexpectedEOF.
func (r *errAfterDataReadCloser) Read(p []byte) (int, error) {
	if len(r.data) <= r.offset {
		return 0, io.ErrUnexpectedEOF
	}

	limit := len(p)
	if 0 < r.chunkSize && r.chunkSize < limit {
		limit = r.chunkSize
	}

	pending := r.data[r.offset:]
	if len(pending) > limit {
		pending = pending[:limit]
	}

	n := copy(p, pending)
	r.offset += n
	return n, nil
}

// Close is a no-op and always returns nil.
func (r *errAfterDataReadCloser) Close() error {
	return nil
}

func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
if p.maxTokensField != "max_completion_tokens" {
Expand Down
Loading