diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227be..2b1c3251f0 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -544,21 +544,17 @@ func (al *AgentLoop) runLLMIteration( }) } - // Retry loop for context/token errors + // Retry loop for context-window and transient provider errors. + // Context-window errors trigger compression; transient errors retry with backoff. maxRetries := 2 + transientBackoff := []time.Duration{1 * time.Second, 2 * time.Second} for retry := 0; retry <= maxRetries; retry++ { response, err = callLLM() if err == nil { break } - errMsg := strings.ToLower(err.Error()) - isContextError := strings.Contains(errMsg, "token") || - strings.Contains(errMsg, "context") || - strings.Contains(errMsg, "invalidparameter") || - strings.Contains(errMsg, "length") - - if isContextError && retry < maxRetries { + if retry < maxRetries && isContextWindowError(err) { logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{ "error": err.Error(), "retry": retry, @@ -581,6 +577,33 @@ func (al *AgentLoop) runLLMIteration( ) continue } + + if retry < maxRetries { + if retryable, reason := isTransientLLMError(err); retryable { + logger.WarnCF("agent", "Transient LLM error detected, retrying", map[string]any{ + "error": err.Error(), + "reason": reason, + "retry": retry, + }) + + if retry == 0 && !constants.IsInternalChannel(opts.Channel) { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: "Temporary LLM error. Retrying...", + }) + } + + if retry < len(transientBackoff) { + select { + case <-ctx.Done(): + return "", iteration, ctx.Err() + case <-time.After(transientBackoff[retry]): + } + } + continue + } + } break } @@ -766,6 +789,61 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c } } +func isContextWindowError(err error) bool { + if err == nil { + return false + } + + errMsg := strings.ToLower(err.Error()) + contextPatterns := []string{ + "context window", + "context length", + "maximum context length", + "max context length", + "too many tokens", + "max message tokens", + "token limit", + "prompt is too long", + "exceed max message tokens", + } + for _, pattern := range contextPatterns { + if strings.Contains(errMsg, pattern) { + return true + } + } + + // Provider-specific "invalid parameter" style errors frequently include token/length hints. + if strings.Contains(errMsg, "invalidparameter") && + (strings.Contains(errMsg, "token") || + strings.Contains(errMsg, "length") || + strings.Contains(errMsg, "context")) { + return true + } + + return false +} + +func isTransientLLMError(err error) (bool, string) { + if err == nil { + return false, "" + } + + classified := providers.ClassifyError(err, "", "") + if classified == nil { + return false, "" + } + + switch classified.Reason { + case providers.FailoverTimeout, providers.FailoverRateLimit: + if classified.Status > 0 { + return true, fmt.Sprintf("%s:%d", classified.Reason, classified.Status) + } + return true, string(classified.Reason) + default: + return false, "" + } +} + // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4414398b17..ce3b4251be 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -631,3 +631,116 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +// TestAgentLoop_TransientLLMErrorRetry verifies transient 5xx failures are retried +// without triggering context compression. +func TestAgentLoop_TransientLLMErrorRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &failFirstMockProvider{ + failures: 1, + failError: fmt.Errorf("API request failed: status: 502 body: bad gateway"), + successResp: "Recovered from transient error", + } + + al := NewAgentLoop(cfg, msgBus, provider) + routedSessionKey := "agent:main:main" + + history := []providers.Message{ + {Role: "system", Content: "System prompt"}, + {Role: "user", Content: "Old message 1"}, + {Role: "assistant", Content: "Old response 1"}, + {Role: "user", Content: "Old message 2"}, + {Role: "assistant", Content: "Old response 2"}, + {Role: "user", Content: "Trigger message"}, + } + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + for _, m := range history { + defaultAgent.Sessions.AddFullMessage(routedSessionKey, m) + } + + response, err := al.ProcessDirectWithChannel( + context.Background(), + "Trigger message", + routedSessionKey, + "test", + "test-chat", + ) + if err != nil { + t.Fatalf("Expected success after transient retry, got error: %v", err) + } + if response != "Recovered from transient error" { + t.Errorf("Expected 'Recovered from transient error', got '%s'", response) + } + if provider.currentCall != 2 { + t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall) + } + + // Transient errors should not trigger context compression. + finalHistory := defaultAgent.Sessions.GetHistory(routedSessionKey) + if len(finalHistory) != 8 { + t.Errorf("Expected no compression for transient retries (len == 8), got %d", len(finalHistory)) + } +} + +// TestAgentLoop_NonRetryableLLMError_NoRetry verifies non-retryable 4xx failures +// return immediately without additional attempts. +func TestAgentLoop_NonRetryableLLMError_NoRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &failFirstMockProvider{ + failures: 1, + failError: fmt.Errorf("API request failed: status: 400 body: invalid request"), + successResp: "should not be used", + } + + al := NewAgentLoop(cfg, msgBus, provider) + _, err = al.ProcessDirectWithChannel( + context.Background(), + "Trigger message", + "test-session-no-retry", + "test", + "test-chat", + ) + if err == nil { + t.Fatal("Expected non-retryable 400 error, got nil") + } + if provider.currentCall != 1 { + t.Errorf("Expected 1 call for non-retryable error, got %d", provider.currentCall) + } +}