Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,15 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}

// Helper to find the mid-point of the conversation
// Find the mid-point of the conversation, avoiding splitting tool call/result pairs.
// A tool-call message (role=assistant with ToolCalls) must be followed by its
// tool-result message (role=tool). Splitting between them causes API errors.
mid := len(conversation) / 2
if mid < len(conversation) && mid > 0 {
if conversation[mid].Role == "tool" {
mid++ // move past the tool result to keep the pair together
}
}

// New history structure:
// 1. System Prompt (with compression note appended)
Expand Down
79 changes: 79 additions & 0 deletions pkg/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,85 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
}

// TestForceCompression_ToolMessageBoundary verifies that forceCompression does not
// split a tool call/result pair when the midpoint falls on a "tool" role message.
// Regression test for: API errors when orphaned tool result messages appear
// without their preceding assistant tool-call message.
func TestForceCompression_ToolMessageBoundary(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         tmpDir,
				Model:             "test-model",
				MaxTokens:         4096,
				MaxToolIterations: 10,
			},
		},
	}

	msgBus := bus.NewMessageBus()
	provider := &mockProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)

	sessionKey := "test-session-tool-boundary"
	defaultAgent := al.registry.GetDefaultAgent()
	if defaultAgent == nil {
		t.Fatal("No default agent found")
	}

	// Construct a history where len(conversation)/2 falls exactly on a "tool" message.
	// history = [system, user, assistant(tool_call), tool, user, assistant, user_trigger]
	// conversation = history[1:6] = [user, assistant(tool_call), tool, user, assistant]
	// len(conversation) = 5, mid = 5/2 = 2 => conversation[2].Role == "tool"
	// Without the fix, this would split between assistant(tool_call) and tool result.
	history := []providers.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "What files are in the current directory?"},
		{Role: "assistant", Content: "", ToolCalls: []providers.ToolCall{
			{ID: "call_1", Name: "exec", Arguments: map[string]any{"command": "ls"}},
		}},
		{Role: "tool", Content: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
		{Role: "user", Content: "Tell me about file1.txt"},
		{Role: "assistant", Content: "file1.txt is a text file."},
		{Role: "user", Content: "Thanks"}, // trigger message
	}

	// Create the session first (AddMessage creates the session entry),
	// then overwrite with our full history via SetHistory.
	defaultAgent.Sessions.AddMessage(sessionKey, "system", "init")
	defaultAgent.Sessions.SetHistory(sessionKey, history)

	// Call forceCompression
	al.forceCompression(defaultAgent, sessionKey)

	// Verify the result
	compressed := defaultAgent.Sessions.GetHistory(sessionKey)

	// Guard against an empty history: indexing compressed[0] below would
	// otherwise panic and abort the whole test binary instead of reporting
	// a regular test failure.
	if len(compressed) == 0 {
		t.Fatal("Compressed history is empty, expected at least the system prompt")
	}

	// Every "tool" message after the system prompt must directly follow an
	// assistant message that carries tool calls; anything else means the
	// tool result was orphaned by the compression split.
	for i := 1; i < len(compressed); i++ {
		if compressed[i].Role != "tool" {
			continue
		}
		if i == 1 {
			// Position 1 is the first conversation message, so nothing but
			// the system prompt can precede it.
			t.Errorf("Tool result message at position %d is orphaned (no preceding assistant with tool call)", i)
		} else if compressed[i-1].Role != "assistant" || len(compressed[i-1].ToolCalls) == 0 {
			t.Errorf("Tool result at position %d is not preceded by assistant with tool calls (preceded by role=%q)", i, compressed[i-1].Role)
		}
	}

	// Verify the system prompt has the compression note
	if !strings.Contains(compressed[0].Content, "Emergency compression") {
		t.Errorf("Expected compression note in system prompt, got: %s", compressed[0].Content)
	}
}

func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion pkg/config/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {

// Check if this is the user's configured provider
if slices.Contains(m.providerNames, userProvider) && userModel != "" {
// Use the user's configured model instead of default
// Use the user's configured model instead of default.
// Also set ModelName so GetModelConfig(userModel) can find this entry.
mc.ModelName = userModel
mc.Model = buildModelWithProtocol(m.protocol, userModel)
} else if userProvider == "" && userModel != "" && !legacyModelNameApplied {
// Legacy config: no explicit provider field but model is specified
Expand Down
69 changes: 69 additions & 0 deletions pkg/config/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,72 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
}
}

// TestConvertProvidersToModelList_ModelNameMatchesUserModel checks that ModelName
// is set to the user's configured model when the provider matches.
// This ensures GetModelConfig(userModel) can find the migrated entry.
// Regression test for: gateway startup failure when user model differs from provider name.
func TestConvertProvidersToModelList_ModelNameMatchesUserModel(t *testing.T) {
	cfg := &Config{
		Agents: AgentsConfig{
			Defaults: AgentDefaults{
				Provider: "moonshot",
				Model:    "k2p5",
			},
		},
		Providers: ProvidersConfig{
			Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
		},
	}

	result := ConvertProvidersToModelList(cfg)

	if len(result) != 1 {
		t.Fatalf("len(result) = %d, want 1", len(result))
	}

	// ModelName must match the user's configured model, not the provider name.
	// Without this, GetModelConfig("k2p5") would fail because it would look
	// for ModelName == "k2p5" but find ModelName == "moonshot".
	if result[0].ModelName != "k2p5" {
		t.Errorf("ModelName = %q, want %q (must match user's model for GetModelConfig lookup)", result[0].ModelName, "k2p5")
	}

	if result[0].Model != "moonshot/k2p5" {
		t.Errorf("Model = %q, want %q", result[0].Model, "moonshot/k2p5")
	}

	// Other providers (not matching the user's configured provider) should keep their provider name
	cfg2 := &Config{
		Agents: AgentsConfig{
			Defaults: AgentDefaults{
				Provider: "moonshot",
				Model:    "k2p5",
			},
		},
		Providers: ProvidersConfig{
			OpenAI:   OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
			Moonshot: ProviderConfig{APIKey: "sk-kimi-test"},
		},
	}

	result2 := ConvertProvidersToModelList(cfg2)

	if len(result2) != 2 {
		t.Fatalf("len(result2) = %d, want 2", len(result2))
	}

	// Track which providers were actually checked so the assertions below
	// cannot pass vacuously when an entry carries an unexpected API key.
	var sawOpenAI, sawMoonshot bool
	for _, mc := range result2 {
		switch mc.APIKey {
		case "sk-openai":
			sawOpenAI = true
			// OpenAI is not the user's provider, should keep default ModelName
			if mc.ModelName != "openai" {
				t.Errorf("OpenAI ModelName = %q, want %q (non-matching provider keeps default)", mc.ModelName, "openai")
			}
		case "sk-kimi-test":
			sawMoonshot = true
			// Moonshot is the user's provider, ModelName must be the user's model
			if mc.ModelName != "k2p5" {
				t.Errorf("Moonshot ModelName = %q, want %q (matching provider uses user model)", mc.ModelName, "k2p5")
			}
		default:
			t.Errorf("unexpected APIKey %q in migrated model list (ModelName = %q)", mc.APIKey, mc.ModelName)
		}
	}

	if !sawOpenAI {
		t.Error("no migrated entry found for the OpenAI provider (APIKey \"sk-openai\")")
	}
	if !sawMoonshot {
		t.Error("no migrated entry found for the Moonshot provider (APIKey \"sk-kimi-test\")")
	}
}
4 changes: 2 additions & 2 deletions pkg/tools/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ var (
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
// Match disk wiping commands (must be followed by space/args)
// Match disk wiping commands, avoid matching --format flags
regexp.MustCompile(
`\b(format|mkfs|diskpart)\b\s`,
`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
),
regexp.MustCompile(`\bdd\s+if=`),
// Block writes to block devices (all common naming schemes).
Expand Down
50 changes: 50 additions & 0 deletions pkg/tools/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,56 @@ func TestShellTool_BlockDevices(t *testing.T) {
}
}

// TestShellTool_DenyPattern_DiskWiping exercises the safety guard's deny pattern
// for disk wiping commands (format, mkfs, diskpart). Commands that invoke one of
// them directly or after a shell separator must be rejected, while harmless
// look-alikes such as a --format flag must pass through.
func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
	tool, err := NewExecTool("", false)
	if err != nil {
		t.Fatalf("unable to configure exec tool: %s", err)
	}

	// Single table: wantBlocked selects both the subtest name prefix and
	// which expectation is enforced on the guard's message.
	cases := []struct {
		name        string
		cmd         string
		wantBlocked bool
	}{
		// Disk wiping commands: the guard must block these.
		{name: "format with space", cmd: "format c:", wantBlocked: true},
		{name: "mkfs standalone", cmd: "mkfs /dev/sda", wantBlocked: true},
		{name: "semicolon format", cmd: "echo hello; format c:", wantBlocked: true},
		{name: "pipe format", cmd: "echo hello | format c:", wantBlocked: true},
		{name: "and format", cmd: "echo hello && format c:", wantBlocked: true},
		{name: "diskpart standalone", cmd: "diskpart /s script.txt", wantBlocked: true},

		// Not disk wiping: the guard must let these through.
		{name: "--format flag", cmd: "echo test --format json", wantBlocked: false},
		{name: "go fmt", cmd: "echo go fmt ./...", wantBlocked: false},
	}

	for _, tc := range cases {
		prefix := "allowed_"
		if tc.wantBlocked {
			prefix = "blocked_"
		}

		t.Run(prefix+tc.name, func(t *testing.T) {
			msg := tool.guardCommand(tc.cmd, "")

			if tc.wantBlocked {
				if !strings.Contains(msg, "blocked") {
					t.Errorf("Expected %q to be blocked by safety guard, got: %q", tc.cmd, msg)
				}
				return
			}

			if msg != "" {
				t.Errorf("Expected %q to be allowed, but it was blocked: %s", tc.cmd, msg)
			}
		})
	}
}

// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
// are allowed even when workspace restriction is active.
func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
Expand Down