diff --git a/.golangci.yaml b/.golangci.yaml index d0ba907169..ea3107ec81 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -7,7 +7,6 @@ linters: - containedctx - cyclop - depguard - - dupl - dupword - err113 - exhaustruct diff --git a/cmd/picoclaw-launcher-tui/internal/ui/channel.go b/cmd/picoclaw-launcher-tui/internal/ui/channel.go index ad91714247..49a6ccc5d3 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/channel.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/channel.go @@ -10,8 +10,8 @@ import ( picoclawconfig "github.com/sipeed/picoclaw/pkg/config" ) -func (s *appState) channelMenu() tview.Primitive { - items := []MenuItem{ +func (s *appState) buildChannelMenuItems() []MenuItem { + return []MenuItem{ {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }}, channelItem( "Telegram", @@ -86,8 +86,10 @@ func (s *appState) channelMenu() tview.Primitive { func() { s.push("channel-wecomapp", s.wecomAppForm()) }, ), } +} - menu := NewMenu("Channels", items) +func (s *appState) channelMenu() tview.Primitive { + menu := NewMenu("Channels", s.buildChannelMenuItems()) menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEsc { s.pop() @@ -103,199 +105,72 @@ func (s *appState) channelMenu() tview.Primitive { } func refreshChannelMenuFromState(menu *Menu, s *appState) { - items := []MenuItem{ - {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }}, - channelItem( - "Telegram", - "Telegram bot settings", - s.config.Channels.Telegram.Enabled, - func() { s.push("channel-telegram", s.telegramForm()) }, - ), - channelItem( - "Discord", - "Discord bot settings", - s.config.Channels.Discord.Enabled, - func() { s.push("channel-discord", s.discordForm()) }, - ), - channelItem( - "QQ", - "QQ bot settings", - s.config.Channels.QQ.Enabled, - func() { s.push("channel-qq", s.qqForm()) }, - ), - channelItem( - "MaixCam", - "MaixCam gateway", - s.config.Channels.MaixCam.Enabled, - func() { s.push("channel-maixcam", s.maixcamForm()) }, - ), - channelItem( - "WhatsApp", - "WhatsApp bridge", - s.config.Channels.WhatsApp.Enabled, - func() { s.push("channel-whatsapp", s.whatsappForm()) }, - ), - channelItem( - "Feishu", - "Feishu bot settings", - s.config.Channels.Feishu.Enabled, - func() { s.push("channel-feishu", s.feishuForm()) }, - ), - channelItem( - "DingTalk", - "DingTalk bot settings", - s.config.Channels.DingTalk.Enabled, - func() { s.push("channel-dingtalk", s.dingtalkForm()) }, - ), - channelItem( - "Slack", - "Slack bot settings", - s.config.Channels.Slack.Enabled, - func() { s.push("channel-slack", s.slackForm()) }, - ), - channelItem( - "LINE", - "LINE bot settings", - s.config.Channels.LINE.Enabled, - func() { s.push("channel-line", s.lineForm()) }, - ), - channelItem( - "OneBot", - "OneBot settings", - s.config.Channels.OneBot.Enabled, - func() { s.push("channel-onebot", s.onebotForm()) }, - ), - channelItem( - "WeCom", - "WeCom bot settings", - s.config.Channels.WeCom.Enabled, - func() { s.push("channel-wecom", s.wecomForm()) }, - ), - channelItem( - "WeCom App", - "WeCom App settings", - s.config.Channels.WeComApp.Enabled, - func() { s.push("channel-wecomapp", s.wecomAppForm()) }, - ), - } - menu.applyItems(items) + menu.applyItems(s.buildChannelMenuItems()) } func (s *appState) telegramForm() tview.Primitive { cfg := &s.config.Channels.Telegram - form := baseChannelForm("Telegram", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) { cfg.Proxy = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) discordForm() tview.Primitive { cfg := &s.config.Channels.Discord - form := baseChannelForm("Discord", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) { cfg.MentionOnly = checked }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) qqForm() tview.Primitive { cfg := &s.config.Channels.QQ - form := baseChannelForm("QQ", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) { cfg.AppID = strings.TrimSpace(text) }) form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) { cfg.AppSecret = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) maixcamForm() tview.Primitive { cfg := &s.config.Channels.MaixCam - form := baseChannelForm("MaixCam", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Host", cfg.Host, 64, nil, func(text string) { cfg.Host = strings.TrimSpace(text) }) addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) whatsappForm() tview.Primitive { cfg := &s.config.Channels.WhatsApp - form := baseChannelForm("WhatsApp", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) { cfg.BridgeURL = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) feishuForm() tview.Primitive { cfg := &s.config.Channels.Feishu - form := baseChannelForm("Feishu", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) { cfg.AppID = strings.TrimSpace(text) }) @@ -308,66 +183,39 @@ func (s *appState) feishuForm() tview.Primitive { form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) { cfg.VerificationToken = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) dingtalkForm() tview.Primitive { cfg := &s.config.Channels.DingTalk - form := baseChannelForm("DingTalk", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) { cfg.ClientID = strings.TrimSpace(text) }) form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) { cfg.ClientSecret = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) slackForm() tview.Primitive { cfg := &s.config.Channels.Slack - form := baseChannelForm("Slack", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Bot Token", cfg.BotToken, 128, nil, func(text string) { cfg.BotToken = strings.TrimSpace(text) }) form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) { cfg.AppToken = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) lineForm() tview.Primitive { cfg := &s.config.Channels.LINE - form := baseChannelForm("LINE", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) { cfg.ChannelSecret = strings.TrimSpace(text) }) @@ -381,22 +229,13 @@ func (s *appState) lineForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) onebotForm() tview.Primitive { cfg := &s.config.Channels.OneBot - form := baseChannelForm("OneBot", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) { cfg.WSUrl = strings.TrimSpace(text) }) @@ -418,22 +257,13 @@ func (s *appState) onebotForm() tview.Primitive { cfg.GroupTriggerPrefix = splitCSV(text) }, ) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) wecomForm() tview.Primitive { cfg := &s.config.Channels.WeCom - form := baseChannelForm("WeCom", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) @@ -450,9 +280,7 @@ func (s *appState) wecomForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) addIntField( form, "Reply Timeout", @@ -464,14 +292,7 @@ func (s *appState) wecomForm() tview.Primitive { func (s *appState) wecomAppForm() tview.Primitive { cfg := &s.config.Channels.WeComApp - form := baseChannelForm("WeCom App", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) { cfg.CorpID = strings.TrimSpace(text) }) @@ -492,9 +313,7 @@ func (s *appState) wecomAppForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) addIntField( form, "Reply Timeout", @@ -504,6 +323,23 @@ func (s *appState) wecomAppForm() tview.Primitive { return wrapWithBack(form, s) } +func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) { + return func(v bool) { + *enabledPtr = v + s.dirty = true + refreshMainMenuIfPresent(s) + if menu, ok := s.menus["channel"]; ok { + refreshChannelMenuFromState(menu, s) + } + } +} + +func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) { + form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) { + *allowFrom = splitCSV(text) + }) +} + func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form { form := tview.NewForm() form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title)) diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index af1bf2eadd..4f41ecd1cc 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { } func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "step-3.5-flash", - }, + tests := []struct { + name string + aliasName string + modelName string + apiBase string + wantProvider string + wantModel string + }{ + { + name: "alias with provider prefix", + aliasName: "step-3.5-flash", + modelName: "openrouter/stepfun/step-3.5-flash:free", + apiBase: "https://openrouter.ai/api/v1", + wantProvider: "openrouter", + wantModel: "stepfun/step-3.5-flash:free", }, - ModelList: []config.ModelConfig{ - { - ModelName: "step-3.5-flash", - Model: "openrouter/stepfun/step-3.5-flash:free", - APIBase: "https://openrouter.ai/api/v1", - }, + { + name: "alias without provider prefix", + aliasName: "glm-5", + modelName: "glm-5", + apiBase: "https://api.z.ai/api/coding/paas/v4", + wantProvider: "openai", + wantModel: "glm-5", }, } - provider := &mockProvider{} - agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) - - if len(agent.Candidates) != 1 { - t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) - } - if agent.Candidates[0].Provider != "openrouter" { - t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter") - } - if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" { - t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free") - } -} - -func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "glm-5", - }, - }, - ModelList: []config.ModelConfig{ - { - ModelName: "glm-5", - Model: "glm-5", - APIBase: "https://api.z.ai/api/coding/paas/v4", - }, - }, - } - - provider := &mockProvider{} - agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) - - if len(agent.Candidates) != 1 { - t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) - } - if agent.Candidates[0].Provider != "openai" { - t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai") - } - if agent.Candidates[0].Model != "glm-5" { - t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: tt.aliasName, + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: tt.aliasName, + Model: tt.modelName, + APIBase: tt.apiBase, + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if len(agent.Candidates) != 1 { + t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) + } + if agent.Candidates[0].Provider != tt.wantProvider { + t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider) + } + if agent.Candidates[0].Model != tt.wantModel { + t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel) + } + }) } } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 801b6a46ed..51cca90cf0 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -26,16 +26,15 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } -func TestRecordLastChannel(t *testing.T) { - // Create temp workspace +func newTestAgentLoop( + t *testing.T, +) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) { + t.Helper() tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tmpDir) - - // Create test config - cfg := &config.Config{ + cfg = &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, @@ -45,74 +44,43 @@ func TestRecordLastChannel(t *testing.T) { }, }, } + msgBus = bus.NewMessageBus() + provider = &mockProvider{} + al = NewAgentLoop(cfg, msgBus, provider) + return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) } +} - // Create agent loop - msgBus := bus.NewMessageBus() - provider := &mockProvider{} - al := NewAgentLoop(cfg, msgBus, provider) +func TestRecordLastChannel(t *testing.T) { + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() - // Test RecordLastChannel testChannel := "test-channel" - err = al.RecordLastChannel(testChannel) - if err != nil { + if err := al.RecordLastChannel(testChannel); err != nil { t.Fatalf("RecordLastChannel failed: %v", err) } - - // Verify channel was saved - lastChannel := al.state.GetLastChannel() - if lastChannel != testChannel { - t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel) + if got := al.state.GetLastChannel(); got != testChannel { + t.Errorf("Expected channel '%s', got '%s'", testChannel, got) } - - // Verify persistence by creating a new agent loop al2 := NewAgentLoop(cfg, msgBus, provider) - if al2.state.GetLastChannel() != testChannel { - t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel()) + if got := al2.state.GetLastChannel(); got != testChannel { + t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got) } } func TestRecordLastChatID(t *testing.T) { - // Create temp workspace - tmpDir, err := os.MkdirTemp("", "agent-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() - // Create test config - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } - - // Create agent loop - msgBus := bus.NewMessageBus() - provider := &mockProvider{} - al := NewAgentLoop(cfg, msgBus, provider) - - // Test RecordLastChatID testChatID := "test-chat-id-123" - err = al.RecordLastChatID(testChatID) - if err != nil { + if err := al.RecordLastChatID(testChatID); err != nil { t.Fatalf("RecordLastChatID failed: %v", err) } - - // Verify chat ID was saved - lastChatID := al.state.GetLastChatID() - if lastChatID != testChatID { - t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID) + if got := al.state.GetLastChatID(); got != testChatID { + t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got) } - - // Verify persistence by creating a new agent loop al2 := NewAgentLoop(cfg, msgBus, provider) - if al2.state.GetLastChatID() != testChatID { - t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID()) + if got := al2.state.GetLastChatID(); got != testChatID { + t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got) } } diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 155e50b397..be48f85fc4 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -539,86 +539,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork }) } -func (m *Manager) dispatchOutbound(ctx context.Context) { - logger.InfoC("channels", "Outbound dispatcher started") +func dispatchLoop[M any]( + ctx context.Context, + m *Manager, + subscribe func(context.Context) (M, bool), + getChannel func(M) string, + enqueue func(context.Context, *channelWorker, M) bool, + startMsg, stopMsg, unknownMsg, noWorkerMsg string, +) { + logger.InfoC("channels", startMsg) for { - msg, ok := m.bus.SubscribeOutbound(ctx) + msg, ok := subscribe(ctx) if !ok { - logger.InfoC("channels", "Outbound dispatcher stopped") + logger.InfoC("channels", stopMsg) return } + channel := getChannel(msg) + // Silently skip internal channels - if constants.IsInternalChannel(msg.Channel) { + if constants.IsInternalChannel(channel) { continue } m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channel] + w, wExists := m.workers[channel] m.mu.RUnlock() if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ - "channel": msg.Channel, - }) + logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel}) continue } if wExists && w != nil { - select { - case w.queue <- msg: - case <-ctx.Done(): + if !enqueue(ctx, w, msg) { return } } else if exists { - logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{ - "channel": msg.Channel, - }) + logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel}) } } } -func (m *Manager) dispatchOutboundMedia(ctx context.Context) { - logger.InfoC("channels", "Outbound media dispatcher started") - - for { - msg, ok := m.bus.SubscribeOutboundMedia(ctx) - if !ok { - logger.InfoC("channels", "Outbound media dispatcher stopped") - return - } - - // Silently skip internal channels - if constants.IsInternalChannel(msg.Channel) { - continue - } - - m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] - m.mu.RUnlock() - - if !exists { - logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{ - "channel": msg.Channel, - }) - continue - } +func (m *Manager) dispatchOutbound(ctx context.Context) { + dispatchLoop( + ctx, m, + m.bus.SubscribeOutbound, + func(msg bus.OutboundMessage) string { return msg.Channel }, + func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { + select { + case w.queue <- msg: + return true + case <-ctx.Done(): + return false + } + }, + "Outbound dispatcher started", + "Outbound dispatcher stopped", + "Unknown channel for outbound message", + "Channel has no active worker, skipping message", + ) +} - if wExists && w != nil { +func (m *Manager) dispatchOutboundMedia(ctx context.Context) { + dispatchLoop( + ctx, m, + m.bus.SubscribeOutboundMedia, + func(msg bus.OutboundMediaMessage) string { return msg.Channel }, + func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { select { case w.mediaQueue <- msg: + return true case <-ctx.Done(): - return + return false } - } else if exists { - logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{ - "channel": msg.Channel, - }) - } - } + }, + "Outbound media dispatcher started", + "Outbound media dispatcher stopped", + "Unknown channel for outbound media message", + "Channel has no active worker, skipping media message", + ) } // runMediaWorker processes outbound media messages for a single channel. diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 292a71fd28..b793403158 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -342,18 +342,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp return result.MediaID, nil } -// sendImageMessage sends an image message using a media_id. -func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { +// sendWeComMessage marshals payload and POSTs it to the WeCom message API. +func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error { apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - msg := WeComImageMessage{ - ToUser: userID, - MsgType: "image", - AgentID: c.config.AgentID, - } - msg.Image.MediaID = mediaID - - jsonData, err := json.Marshal(msg) + jsonData, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal message: %w", err) } @@ -400,6 +393,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use return nil } +// sendImageMessage sends an image message using a media_id. +func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { + msg := WeComImageMessage{ + ToUser: userID, + MsgType: "image", + AgentID: c.config.AgentID, + } + msg.Image.MediaID = mediaID + return c.sendWeComMessage(ctx, accessToken, msg) +} + // WebhookPath returns the path for registering on the shared HTTP server. func (c *WeComAppChannel) WebhookPath() string { if c.config.WebhookPath != "" { @@ -722,63 +726,15 @@ func (c *WeComAppChannel) getAccessToken() string { return c.accessToken } -// sendTextMessage sends a text message to a user +// sendTextMessage sends a text message to a user. func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { - apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - msg := WeComTextMessage{ ToUser: userID, MsgType: "text", AgentID: c.config.AgentID, } msg.Text.Content = content - - jsonData, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(body, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil + return c.sendWeComMessage(ctx, accessToken, msg) } // handleHealth handles health check requests diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go index 5420949de7..8a5faaaa8e 100644 --- a/pkg/channels/wecom/app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) { }) } -func TestWeComAppPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7Unpad(tt.input) - if tt.expected == nil { - // This case should return an error - if err == nil { - t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7Unpad() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) - } - }) - } -} - func TestWeComAppHandleVerification(t *testing.T) { msgBus := bus.NewMessageBus() aesKey := generateTestAESKeyApp() diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go index 328b145c2d..1950800c96 100644 --- a/pkg/channels/wecom/bot_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - t.Run("valid direct message callback", func(t *testing.T) { - // Create JSON message for direct chat (single) - jsonMsg := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - // Encrypt message + runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder { + t.Helper() encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper encryptedWrapper := struct { XMLName xml.Name `xml:"xml"` Encrypt string `xml:"Encrypt"` @@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { Encrypt: encrypted, } wrapperData, _ := xml.Marshal(encryptedWrapper) - timestamp := "1234567890" nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest( http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData), ) w := httptest.NewRecorder() - ch.handleMessageCallback(context.Background(), w, req) + return w + } + t.Run("valid direct message callback", func(t *testing.T) { + w := runBotMessageCallback(t, `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }`) if w.Code != http.StatusOK { t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) } @@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { }) t.Run("valid group message callback", func(t *testing.T) { - // Create JSON message for group chat - jsonMsg := `{ + w := runBotMessageCallback(t, `{ "msgid": "test_msg_id_456", "aibotid": "test_aibot_id", "chatid": "group_chat_id_123", @@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", "msgtype": "text", "text": {"content": "Hello Group"} - }` - - // Encrypt message - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - + }`) if w.Code != http.StatusOK { t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) } diff --git a/pkg/config/config.go b/pkg/config/config.go index d84772d2b0..4210bf3091 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -742,25 +742,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig { // HasProvidersConfig checks if any provider in the old providers config has configuration. func (c *Config) HasProvidersConfig() bool { - v := c.Providers - return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" || - v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" || - v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" || - v.Groq.APIKey != "" || v.Groq.APIBase != "" || - v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" || - v.VLLM.APIKey != "" || v.VLLM.APIBase != "" || - v.Gemini.APIKey != "" || v.Gemini.APIBase != "" || - v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" || - v.Ollama.APIKey != "" || v.Ollama.APIBase != "" || - v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" || - v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" || - v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" || - v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" || - v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" || - v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" || - v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" || - v.Qwen.APIKey != "" || v.Qwen.APIBase != "" || - v.Mistral.APIKey != "" || v.Mistral.APIBase != "" + return !c.Providers.IsEmpty() } // ValidateModelList validates all ModelConfig entries in the model_list. diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go index a7aef8c3a9..3b7eeeefb3 100644 --- a/pkg/heartbeat/service_test.go +++ b/pkg/heartbeat/service_test.go @@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) { } } -func TestExecuteHeartbeat_Error(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - hs := NewHeartbeatService(tmpDir, 30, true) - hs.stopChan = make(chan struct{}) // Enable for testing - - hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - return &tools.ToolResult{ - ForLLM: "Heartbeat failed: connection error", - ForUser: "", - Silent: false, - IsError: true, - Async: false, - } - }) - - // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) - - hs.executeHeartbeat() - - // Check log file for error message - logFile := filepath.Join(tmpDir, "heartbeat.log") - data, err := os.ReadFile(logFile) - if err != nil { - t.Fatalf("Failed to read log file: %v", err) - } - - logContent := string(data) - if logContent == "" { - t.Error("Expected log file to contain error message") - } -} - -func TestExecuteHeartbeat_Silent(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - hs := NewHeartbeatService(tmpDir, 30, true) - hs.stopChan = make(chan struct{}) // Enable for testing - - hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - return &tools.ToolResult{ - ForLLM: "Heartbeat completed successfully", - ForUser: "", - Silent: true, - IsError: false, - Async: false, - } - }) - - // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) - - hs.executeHeartbeat() - - // Check log file for completion message - logFile := filepath.Join(tmpDir, "heartbeat.log") - data, err := os.ReadFile(logFile) - if err != nil { - t.Fatalf("Failed to read log file: %v", err) - } - - logContent := string(data) - if logContent == "" { - t.Error("Expected log file to contain completion message") +func TestExecuteHeartbeat_ResultLogging(t *testing.T) { + tests := []struct { + name string + result *tools.ToolResult + wantLog string + }{ + { + name: "error result", + result: &tools.ToolResult{ + ForLLM: "Heartbeat failed: connection error", + ForUser: "", + Silent: false, + IsError: true, + Async: false, + }, + wantLog: "error message", + }, + { + name: "silent result", + result: &tools.ToolResult{ + ForLLM: "Heartbeat completed successfully", + ForUser: "", + Silent: true, + IsError: false, + Async: false, + }, + wantLog: "completion message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return tt.result + }) + + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) + hs.executeHeartbeat() + + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + if string(data) == "" { + t.Errorf("Expected log file to contain %s", tt.wantLog) + } + }) } } diff --git a/pkg/migrate/internal/common_test.go b/pkg/migrate/internal/common_test.go index a089157f57..a67293c197 100644 --- a/pkg/migrate/internal/common_test.go +++ b/pkg/migrate/internal/common_test.go @@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) { assert.GreaterOrEqual(t, len(actions), 1) } -func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) { - tmpDir := t.TempDir() - srcWorkspace := filepath.Join(tmpDir, "src", "workspace") - dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") - - err := os.MkdirAll(srcWorkspace, 0o755) - require.NoError(t, err) - - err = os.MkdirAll(dstWorkspace, 0o755) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) - require.NoError(t, err) - - actions, err := PlanWorkspaceMigration( - srcWorkspace, - dstWorkspace, - []string{"file1.txt"}, - []string{}, - false, - ) - require.NoError(t, err) - - require.GreaterOrEqual(t, len(actions), 1) - assert.Equal(t, ActionBackup, actions[0].Type) -} - -func TestPlanWorkspaceMigrationForce(t *testing.T) { - tmpDir := t.TempDir() - srcWorkspace := filepath.Join(tmpDir, "src", "workspace") - dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") - - err := os.MkdirAll(srcWorkspace, 0o755) - require.NoError(t, err) - - err = os.MkdirAll(dstWorkspace, 0o755) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) - require.NoError(t, err) - - actions, err := PlanWorkspaceMigration( - srcWorkspace, - dstWorkspace, - []string{"file1.txt"}, - []string{}, - true, - ) - require.NoError(t, err) +func TestPlanWorkspaceMigrationExistingFile(t *testing.T) { + tests := []struct { + name string + force bool + wantActionType ActionType + }{ + { + name: "backup when not forced", + force: false, + wantActionType: ActionBackup, + }, + { + name: "copy when forced", + force: true, + wantActionType: ActionCopy, + }, + } - require.GreaterOrEqual(t, len(actions), 1) - assert.Equal(t, ActionCopy, actions[0].Type) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + srcWorkspace := filepath.Join(tmpDir, "src", "workspace") + dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") + + err := os.MkdirAll(srcWorkspace, 0o755) + require.NoError(t, err) + + err = os.MkdirAll(dstWorkspace, 0o755) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) + require.NoError(t, err) + + actions, err := PlanWorkspaceMigration( + srcWorkspace, + dstWorkspace, + []string{"file1.txt"}, + []string{}, + tt.force, + ) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(actions), 1) + assert.Equal(t, tt.wantActionType, actions[0].Type) + }) + } } func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) { diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go index 74ec33b98d..6c4f6a767b 100644 --- a/pkg/providers/claude_cli_provider.go +++ b/pkg/providers/claude_cli_provider.go @@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe } if len(tools) > 0 { - parts = append(parts, p.buildToolsPrompt(tools)) + parts = append(parts, buildCLIToolsPrompt(tools)) } return strings.Join(parts, "\n\n") } -// buildToolsPrompt creates the tool definitions section for the system prompt. -func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string { - var sb strings.Builder - - sb.WriteString("## Available Tools\n\n") - sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") - sb.WriteString("```json\n") - sb.WriteString( - `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, - ) - sb.WriteString("\n```\n\n") - sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") - sb.WriteString("### Tool Definitions:\n\n") - - for _, tool := range tools { - if tool.Type != "function" { - continue - } - sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) - if tool.Function.Description != "" { - sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) - } - if len(tool.Function.Parameters) > 0 { - paramsJSON, _ := json.Marshal(tool.Function.Parameters) - sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) - } - sb.WriteString("\n") - } - - return sb.String() -} - // parseClaudeCliResponse parses the JSON output from the claude CLI. func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) { var resp claudeCliJSONResponse diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index 3a3cafaca9..d4d648f5a5 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) { // --- buildToolsPrompt tests --- func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}}, {Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if strings.Contains(got, "skip_me") { t.Error("buildToolsPrompt() should skip non-function tools") } @@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { } func TestBuildToolsPrompt_NoDescription(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if !strings.Contains(got, "bare_tool") { t.Error("should include tool name") } @@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) { } func TestBuildToolsPrompt_NoParameters(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "function", Function: ToolFunctionDefinition{ Name: "no_params_tool", Description: "A tool with no parameters", }}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if strings.Contains(got, "Parameters:") { t.Error("should not include Parameters: section when nil") } diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go index 4c783ece53..13f53ad9eb 100644 --- a/pkg/providers/codex_cli_provider.go +++ b/pkg/providers/codex_cli_provider.go @@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio } if len(tools) > 0 { - sb.WriteString(p.buildToolsPrompt(tools)) + sb.WriteString(buildCLIToolsPrompt(tools)) sb.WriteString("\n\n") } @@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio return sb.String() } -// buildToolsPrompt creates a tool definitions section for the prompt. -func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string { - var sb strings.Builder - - sb.WriteString("## Available Tools\n\n") - sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") - sb.WriteString("```json\n") - sb.WriteString( - `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, - ) - sb.WriteString("\n```\n\n") - sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") - sb.WriteString("### Tool Definitions:\n\n") - - for _, tool := range tools { - if tool.Type != "function" { - continue - } - sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) - if tool.Function.Description != "" { - sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) - } - if len(tool.Function.Parameters) > 0 { - paramsJSON, _ := json.Marshal(tool.Function.Parameters) - sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) - } - sb.WriteString("\n") - } - - return sb.String() -} - // codexEvent represents a single JSONL event from `codex exec --json`. type codexEvent struct { Type string `json:"type"` diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go index 49218b1b10..a33e1eb5ce 100644 --- a/pkg/providers/toolcall_utils.go +++ b/pkg/providers/toolcall_utils.go @@ -5,7 +5,43 @@ package providers -import "encoding/json" +import ( + "encoding/json" + "fmt" + "strings" +) + +// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt. +func buildCLIToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString( + `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, + ) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} // NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated. // It handles cases where Name/Arguments might be in different locations (top-level vs Function)