diff --git a/pkg/agent/context.go b/pkg/agent/context.go index a9db5afddc..7bd55d4ab0 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -229,8 +229,19 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{}) continue } - last := sanitized[len(sanitized)-1] - if last.Role != "assistant" || len(last.ToolCalls) == 0 { + // Walk backwards to find the nearest assistant message, + // skipping over any preceding tool messages (multi-tool-call case). + foundAssistant := false + for i := len(sanitized) - 1; i >= 0; i-- { + if sanitized[i].Role == "tool" { + continue + } + if sanitized[i].Role == "assistant" && len(sanitized[i].ToolCalls) > 0 { + foundAssistant = true + } + break + } + if !foundAssistant { logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{}) continue } diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go new file mode 100644 index 0000000000..e023c9c306 --- /dev/null +++ b/pkg/agent/context_test.go @@ -0,0 +1,209 @@ +package agent + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func msg(role, content string) providers.Message { + return providers.Message{Role: role, Content: content} +} + +func assistantWithTools(toolIDs ...string) providers.Message { + calls := make([]providers.ToolCall, len(toolIDs)) + for i, id := range toolIDs { + calls[i] = providers.ToolCall{ID: id, Type: "function"} + } + return providers.Message{Role: "assistant", ToolCalls: calls} +} + +func toolResult(id string) providers.Message { + return providers.Message{Role: "tool", Content: "result", ToolCallID: id} +} + +func TestSanitizeHistoryForProvider_EmptyHistory(t *testing.T) { + result := sanitizeHistoryForProvider(nil) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } + + result = sanitizeHistoryForProvider([]providers.Message{}) + if len(result) != 0 { + t.Fatalf("expected empty, got %d messages", len(result)) + } +} + +func TestSanitizeHistoryForProvider_SingleToolCall(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + assistantWithTools("A"), + toolResult("A"), + msg("assistant", "done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_MultiToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "both done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 5 { + t.Fatalf("expected 5 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAfterPlainAssistant(t *testing.T) { + history := []providers.Message{ + msg("user", "hi"), + msg("assistant", "thinking"), + assistantWithTools("A"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_OrphanedLeadingTool(t *testing.T) { + history := []providers.Message{ + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterUserDropped(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_ToolAfterAssistantNoToolCalls(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + toolResult("A"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 2 { + t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant") +} + +func TestSanitizeHistoryForProvider_AssistantToolCallAtStart(t *testing.T) { + history := []providers.Message{ + assistantWithTools("A"), + toolResult("A"), + msg("user", "hello"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user") +} + +func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) { + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + msg("assistant", "done"), + msg("user", "hi"), + assistantWithTools("C"), + toolResult("C"), + msg("assistant", "done again"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 9 { + t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) { + history := []providers.Message{ + msg("user", "start"), + assistantWithTools("A", "B"), + toolResult("A"), + toolResult("B"), + assistantWithTools("C", "D"), + toolResult("C"), + toolResult("D"), + msg("assistant", "all done"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 8 { + t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant") +} + +func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) { + history := []providers.Message{ + msg("user", "hello"), + msg("assistant", "hi"), + msg("user", "how are you"), + msg("assistant", "fine"), + } + + result := sanitizeHistoryForProvider(history) + if len(result) != 4 { + t.Fatalf("expected 4 messages, got %d", len(result)) + } + assertRoles(t, result, "user", "assistant", "user", "assistant") +} + +func roles(msgs []providers.Message) []string { + r := make([]string, len(msgs)) + for i, m := range msgs { + r[i] = m.Role + } + return r +} + +func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) { + t.Helper() + if len(msgs) != len(expected) { + t.Fatalf("role count mismatch: got %v, want %v", roles(msgs), expected) + } + for i, exp := range expected { + if msgs[i].Role != exp { + t.Errorf("message[%d]: got role %q, want %q", i, msgs[i].Role, exp) + } + } +}