diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 12e3cdd4d..bde34561a 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -106,6 +106,8 @@ Your workspace is at: %s 4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content. +5. **Team delegation** - For any task that is non-trivial, multi-step, or involves distinct concerns (e.g. "convert React to Vue", "build a feature", "analyze and report"), you MUST use the 'team' tool to delegate and parallelize. Do NOT attempt to handle complex tasks inline by calling tools one by one yourself. Decompose first, delegate second, then report the outcome. + %s`, version, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath, toolDiscovery) } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 725d42614..b4e81f782 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -163,7 +163,10 @@ func registerSharedTools( if cfg.Tools.IsToolEnabled("web") { searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey(), cfg.Tools.Web.Brave.APIKeys()), + BraveAPIKeys: config.MergeAPIKeys( + cfg.Tools.Web.Brave.APIKey(), + cfg.Tools.Web.Brave.APIKeys(), + ), BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, BraveEnabled: cfg.Tools.Web.Brave.Enabled, TavilyAPIKeys: config.MergeAPIKeys( @@ -196,7 +199,11 @@ func registerSharedTools( Proxy: cfg.Tools.Web.Proxy, }) if err != nil { - logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + logger.ErrorCF( + "agent", + "Failed to create web search tool", + map[string]any{"error": err.Error()}, + ) } else if searchTool != nil { agent.Tools.Register(searchTool) } @@ -209,7 +216,11 @@ func registerSharedTools( cfg.Tools.Web.FetchLimitBytes, cfg.Tools.Web.PrivateHostWhitelist) if err != nil { - 
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + logger.ErrorCF( + "agent", + "Failed to create web fetch tool", + map[string]any{"error": err.Error()}, + ) } else { agent.Tools.Register(fetchTool) } @@ -284,12 +295,39 @@ func registerSharedTools( } } + // Team tool + teamSubagentManager := tools.NewSubagentManager( + provider, + agent.Model, + agent.Candidates, + agent.Workspace, + cfg.Tools.Team, + msgBus, + ) + teamSubagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + + teamTool := tools.NewTeamTool(teamSubagentManager, cfg) + if cfg.Tools.IsToolEnabled("team") { + teamTool.SetSpawner(NewSubTurnSpawner(al)) + agent.Tools.Register(teamTool) + } + + // Share the fully-built registry back to team subagent manager + teamSubagentManager.SetTools(agent.Tools) + // Spawn and spawn_status tools share a SubagentManager. // Construct it when either tool is enabled (both require subagent). spawnEnabled := cfg.Tools.IsToolEnabled("spawn") spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status") if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") { - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace) + subagentManager := tools.NewSubagentManager( + provider, + agent.Model, + agent.Candidates, + agent.Workspace, + cfg.Tools.Team, + msgBus, + ) subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) // Set the spawner that links into AgentLoop's turnState @@ -472,7 +510,12 @@ func (al *AgentLoop) Run(ctx context.Context) error { "queue_depth": al.pendingSteeringCountForScope(target.SessionKey), }) - continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) + continued, continueErr := al.Continue( + ctx, + target.SessionKey, + target.Channel, + target.ChatID, + ) if continueErr != nil { logger.WarnCF("agent", "Failed to continue queued steering", map[string]any{ @@ -500,14 +543,22 @@ func (al *AgentLoop) 
Run(ctx context.Context) error { "queue_depth": al.pendingSteeringCountForScope(target.SessionKey), }) - continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) + continued, continueErr := al.Continue( + ctx, + target.SessionKey, + target.Channel, + target.ChatID, + ) if continueErr != nil { - logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain", + logger.WarnCF( + "agent", + "Failed to continue queued steering after shutdown drain", map[string]any{ "channel": target.Channel, "chat_id": target.ChatID, "error": continueErr.Error(), - }) + }, + ) return } if continued == "" { @@ -566,11 +617,15 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active msgScope, _, scopeOK := al.resolveSteeringTarget(msg) if !scopeOK || msgScope != activeScope { if err := al.requeueInboundMessage(msg); err != nil { - logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{ - "error": err.Error(), - "channel": msg.Channel, - "sender_id": msg.SenderID, - }) + logger.WarnCF( + "agent", + "Failed to requeue non-steering inbound message", + map[string]any{ + "error": err.Error(), + "channel": msg.Channel, + "sender_id": msg.SenderID, + }, + ) } continue } @@ -604,7 +659,10 @@ func (al *AgentLoop) Stop() { al.running.Store(false) } -func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatID, response string) { +func (al *AgentLoop) publishResponseIfNeeded( + ctx context.Context, + channel, chatID, response string, +) { if response == "" { return } @@ -1054,7 +1112,10 @@ var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) // transcribeAudioInMessage resolves audio media refs, transcribes them, and // replaces audio annotations in msg.Content with the transcribed text. // Returns the (possibly modified) message and true if audio was transcribed. 
-func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) (bus.InboundMessage, bool) { +func (al *AgentLoop) transcribeAudioInMessage( + ctx context.Context, + msg bus.InboundMessage, +) (bus.InboundMessage, bool) { if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 { return msg, false } @@ -1064,7 +1125,11 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou for _, ref := range msg.Media { path, meta, err := al.mediaStore.ResolveWithMeta(ref) if err != nil { - logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err}) + logger.WarnCF( + "voice", + "Failed to resolve media ref", + map[string]any{"ref": ref, "error": err}, + ) continue } if !utils.IsAudioFile(meta.Filename, meta.ContentType) { @@ -1142,7 +1207,11 @@ func (al *AgentLoop) sendTranscriptionFeedback( ReplyToMessageID: messageID, }) if err != nil { - logger.WarnCF("voice", "Failed to send transcription feedback", map[string]any{"error": err.Error()}) + logger.WarnCF( + "voice", + "Failed to send transcription feedback", + map[string]any{"error": err.Error()}, + ) } } @@ -1342,7 +1411,9 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.runAgentLoop(ctx, agent, opts) } -func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { +func (al *AgentLoop) resolveMessageRoute( + msg bus.InboundMessage, +) (routing.ResolvedRoute, *AgentInstance, error) { registry := al.GetRegistry() route := registry.ResolveRoute(routing.RouteInput{ Channel: msg.Channel, @@ -1358,7 +1429,10 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv agent = registry.GetDefaultAgent() } if agent == nil { - return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + return routing.ResolvedRoute{}, nil, fmt.Errorf( + "no agent available for 
route (agent_id=%s)", + route.AgentID, + ) } return route, agent, nil @@ -2559,12 +2633,15 @@ turnLoop: } if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { - logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing", + logger.InfoCF( + "agent", + "Steering arrived after turn completion; continuing turn before finalizing", map[string]any{ "agent_id": ts.agent.ID, "steering_count": len(steerMsgs), "session_key": ts.sessionKey, - }) + }, + ) pendingMessages = append(pendingMessages, steerMsgs...) finalContent = "" goto turnLoop @@ -2612,6 +2689,7 @@ turnLoop: finalContent: finalContent, status: turnStatus, followUps: append([]bus.InboundMessage(nil), ts.followUps...), + messages: append([]providers.Message(nil), messages...), }, nil } @@ -2680,11 +2758,18 @@ func (al *AgentLoop) selectCandidates( "score": score, "threshold": agent.Router.Threshold(), }) - return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()) + return agent.LightCandidates, resolvedCandidateModel( + agent.LightCandidates, + agent.Router.LightModel(), + ) } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) { +func (al *AgentLoop) maybeSummarize( + agent *AgentInstance, + sessionKey string, + turnScope turnEventScope, +) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100 @@ -2718,7 +2803,10 @@ type compressionResult struct { // prompt is built dynamically by BuildMessages and is NOT stored here. // The compression note is recorded in the session summary so that // BuildMessages can include it in the next system prompt. 
-func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) { +func (al *AgentLoop) forceCompression( + agent *AgentInstance, + sessionKey string, +) (compressionResult, bool) { history := agent.Sessions.GetHistory(sessionKey) if len(history) <= 2 { return compressionResult{}, false @@ -2871,7 +2959,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string { } // summarizeSession summarizes the conversation history for a session. -func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) { +func (al *AgentLoop) summarizeSession( + agent *AgentInstance, + sessionKey string, + turnScope turnEventScope, +) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -3159,7 +3251,10 @@ func (al *AgentLoop) handleCommand( } } -func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime { +func (al *AgentLoop) buildCommandsRuntime( + agent *AgentInstance, + opts *processOptions, +) *commands.Runtime { registry := al.GetRegistry() cfg := al.GetConfig() rt := &commands.Runtime{ @@ -3200,7 +3295,10 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt } if agent != nil { rt.GetModelInfo = func() (string, string) { - return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider) + return agent.Model, resolvedCandidateProvider( + agent.Candidates, + cfg.Agents.Defaults.Provider, + ) } rt.SwitchModel = func(value string) (string, error) { value = strings.TrimSpace(value) @@ -3214,7 +3312,12 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return "", fmt.Errorf("failed to initialize model %q: %w", value, err) } - nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks) + nextCandidates := resolveModelCandidates( + cfg, + cfg.Agents.Defaults.Provider, 
+ modelCfg.Model, + agent.Fallbacks, + ) if len(nextCandidates) == 0 { return "", fmt.Errorf("model %q did not resolve to any provider candidates", value) } @@ -3303,7 +3406,10 @@ func (al *AgentLoop) applyExplicitSkillCommand( canonicalSkill, ok := agent.ContextBuilder.ResolveSkillName(fields[1]) if !ok { - return true, true, fmt.Sprintf("Unknown skill: %s\nUse /list skills to see installed skills.", fields[1]) + return true, true, fmt.Sprintf( + "Unknown skill: %s\nUse /list skills to see installed skills.", + fields[1], + ) } if len(fields) == 2 { diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index f5ba412ab..7035c8c1e 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -102,7 +102,9 @@ type subTurnRuntimeConfig struct { // // Parent turn will poll and process it in a later iteration type SubTurnConfig struct { Model string + Provider providers.LLMProvider // non-nil overrides the child agent's provider Tools []tools.Tool + EmptyTools bool // true: child agent gets an empty ToolRegistry (overrides Tools) SystemPrompt string MaxTokens int @@ -220,7 +222,9 @@ func (s *AgentLoopSpawner) SpawnSubTurn( // Convert tools.SubTurnConfig to agent.SubTurnConfig agentCfg := SubTurnConfig{ Model: cfg.Model, + Provider: cfg.Provider, Tools: cfg.Tools, + EmptyTools: cfg.EmptyTools, SystemPrompt: cfg.SystemPrompt, ActualSystemPrompt: cfg.ActualSystemPrompt, InitialMessages: cfg.InitialMessages, @@ -232,6 +236,17 @@ func (s *AgentLoopSpawner) SpawnSubTurn( MaxContextRunes: cfg.MaxContextRunes, } + // Resolve model → provider when only a model name is given (no explicit provider). + // This enables heterogeneous model routing from tool-layer callers (e.g. subagent tool). 
+ if agentCfg.Provider == nil && agentCfg.Model != "" { + if modelCfg, err := s.al.GetConfig().GetModelConfig(agentCfg.Model); err == nil { + if p, m, err := providers.CreateProviderFromConfig(modelCfg); err == nil { + agentCfg.Provider = p + agentCfg.Model = m + } + } + } + return spawnSubTurn(ctx, s.al, parentTS, agentCfg) } @@ -344,9 +359,25 @@ func spawnSubTurn( ephemeralStore := newEphemeralSession(nil) agent := *baseAgent // shallow copy agent.Sessions = ephemeralStore + // Apply model/provider override for heterogeneous agents. + if cfg.Model != "" { + agent.Model = cfg.Model + } + if cfg.Provider != nil { + agent.Provider = cfg.Provider + } // Clone the tool registry so child turn's tool registrations // don't pollute the parent's registry. - if baseAgent.Tools != nil { + if cfg.EmptyTools { + agent.Tools = tools.NewToolRegistry() + } else if cfg.Tools != nil { + // Tools override will be applied via processOptions below. + // Clone parent registry as base, then replace with cfg.Tools entries. 
+ agent.Tools = tools.NewToolRegistry() + for _, t := range cfg.Tools { + agent.Tools.Register(t) + } + } else if baseAgent.Tools != nil { agent.Tools = baseAgent.Tools.Clone() } @@ -477,8 +508,9 @@ func spawnSubTurn( } } else { result = &tools.ToolResult{ - ForLLM: turnRes.finalContent, - ForUser: turnRes.finalContent, + ForLLM: turnRes.finalContent, + ForUser: turnRes.finalContent, + Messages: turnRes.messages, } } diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go index e4970c519..7545069c1 100644 --- a/pkg/agent/turn.go +++ b/pkg/agent/turn.go @@ -43,6 +43,7 @@ type turnResult struct { finalContent string status TurnEndStatus followUps []bus.InboundMessage + messages []providers.Message // ephemeral session history after execution (for state continuation) } type turnState struct { diff --git a/pkg/config/config.go b/pkg/config/config.go index f0d9aa580..33dcb524a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -78,6 +78,24 @@ func (f *FlexibleStringSlice) UnmarshalText(text []byte) error { return nil } +type TeamModelConfig struct { + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` +} + +type TeamToolsConfig struct { + ToolConfig + MaxMembers int `json:"max_members"` + MaxTeamTokens int `json:"max_team_tokens"` + MaxEvaluatorLoops int `json:"max_evaluator_loops"` + MaxTimeoutMinutes int `json:"max_timeout_minutes"` + MaxContextRunes int `json:"max_context_runes"` + DisableAutoReviewer bool `json:"disable_auto_reviewer"` + ReviewerModel string `json:"reviewer_model"` + AllowedStrategies []string `json:"allowed_strategies"` + AllowedModels []TeamModelConfig `json:"allowed_models"` +} + // CurrentVersion is the latest config schema version const CurrentVersion = 1 @@ -997,10 +1015,12 @@ func (c *ModelConfig) SetAPIKey(value string) { } type GatewayConfig struct { - Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` - Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` - HotReload bool `json:"hot_reload" 
env:"PICOCLAW_GATEWAY_HOT_RELOAD"` - LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"` + Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` + Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` + HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"` + // LogLevel controls the logging verbosity for the gateway server. + // Valid values: "debug", "info", "warn", "error", "fatal" (default: "fatal") + LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"` } type ToolDiscoveryConfig struct { @@ -1250,6 +1270,8 @@ type ToolsConfig struct { SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"` SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"` Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"` + SpawnSubAgent ToolConfig `json:"spawn_sub_agent" envPrefix:"PICOCLAW_TOOLS_SPAWN_SUB_AGENT_"` + Team TeamToolsConfig `json:"team" envPrefix:"PICOCLAW_TOOLS_TEAM_"` WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"` WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"` } @@ -1377,7 +1399,10 @@ func LoadConfig(path string) (*Config, error) { var cfg *Config switch versionInfo.Version { case 0: - logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) + logger.InfoF( + "config migrate start", + map[string]any{"from": versionInfo.Version, "to": CurrentVersion}, + ) // Legacy config (no version field) v, e := loadConfigV0(data) if e != nil { @@ -1422,9 +1447,11 @@ func LoadConfig(path string) (*Config, error) { for _, m := range cfg.ModelList { for _, k := range m.apiKeys { if k != "" && !strings.HasPrefix(k, "enc://") && !strings.HasPrefix(k, "file://") { - fmt.Fprintf(os.Stderr, + fmt.Fprintf( + os.Stderr, "picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n", - m.ModelName) + m.ModelName, + ) break // Only warn once per model } } diff 
--git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 39d45013d..3db95f9ac 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -12,6 +12,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/sipeed/picoclaw/pkg/fileutil" @@ -306,6 +307,13 @@ func (t *ReadFileTool) Parameters() map[string]any { } } +func (t *ReadFileTool) UpgradeToConcurrent() Tool { + return &ReadFileTool{ + fs: &ConcurrentFS{baseFS: t.fs}, + maxSize: t.maxSize, + } +} + func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { @@ -521,6 +529,12 @@ func (t *WriteFileTool) Parameters() map[string]any { } } +func (t *WriteFileTool) UpgradeToConcurrent() Tool { + return &WriteFileTool{ + fs: &ConcurrentFS{baseFS: t.fs}, + } +} + func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { @@ -610,6 +624,7 @@ func formatDirEntries(entries []os.DirEntry) *ToolResult { type fileSystem interface { ReadFile(path string) ([]byte, error) WriteFile(path string, data []byte) error + EditFile(path string, editFn func([]byte) ([]byte, error)) error ReadDir(path string) ([]os.DirEntry, error) Open(path string) (fs.File, error) } @@ -655,6 +670,18 @@ func (h *hostFs) Open(path string) (fs.File, error) { return f, nil } +func (h *hostFs) EditFile(path string, editFn func([]byte) ([]byte, error)) error { + data, err := h.ReadFile(path) + if err != nil { + return err + } + newData, err := editFn(data) + if err != nil { + return err + } + return h.WriteFile(path, newData) +} + // sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root. 
type sandboxFs struct { workspace string @@ -786,6 +813,18 @@ func (r *sandboxFs) Open(path string) (fs.File, error) { return f, err } +func (r *sandboxFs) EditFile(path string, editFn func([]byte) ([]byte, error)) error { + data, err := r.ReadFile(path) + if err != nil { + return err + } + newData, err := editFn(data) + if err != nil { + return err + } + return r.WriteFile(path, newData) +} + // whitelistFs wraps a sandboxFs and allows access to specific paths outside // the workspace when they match any of the provided patterns. type whitelistFs struct { @@ -826,6 +865,13 @@ func (w *whitelistFs) Open(path string) (fs.File, error) { return w.sandbox.Open(path) } +func (w *whitelistFs) EditFile(path string, editFn func([]byte) ([]byte, error)) error { + if w.matches(path) { + return w.host.EditFile(path, editFn) + } + return w.sandbox.EditFile(path, editFn) +} + // buildFs returns the appropriate fileSystem implementation based on restriction // settings and optional path whitelist patterns. func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem { @@ -860,3 +906,56 @@ func getSafeRelPath(workspace, path string) (string, error) { return rel, nil } + +// ConcurrencyUpgradeable indicates a Tool operates on files and can be upgraded +// to use a thread-safe locking proxy backend (`ConcurrentFS`) for Parallel or DAG agent teams. +type ConcurrencyUpgradeable interface { + UpgradeToConcurrent() Tool +} + +// Global file locks explicitly for concurrent agent strategies +var globalFileLocks sync.Map // map[string]*sync.RWMutex + +func getPathLock(path string) *sync.RWMutex { + cleanPath := filepath.Clean(path) + actual, _ := globalFileLocks.LoadOrStore(cleanPath, &sync.RWMutex{}) + return actual.(*sync.RWMutex) +} + +// ConcurrentFS is a lightweight proxy wrapper around any `fileSystem`. +// It guarantees thread-safe, race-condition-free access by locking the absolute file path globally. 
+type ConcurrentFS struct { + baseFS fileSystem +} + +func (c *ConcurrentFS) ReadFile(path string) ([]byte, error) { + lock := getPathLock(path) + lock.RLock() + defer lock.RUnlock() + return c.baseFS.ReadFile(path) +} + +func (c *ConcurrentFS) WriteFile(path string, data []byte) error { + lock := getPathLock(path) + lock.Lock() + defer lock.Unlock() + return c.baseFS.WriteFile(path, data) +} + +func (c *ConcurrentFS) EditFile(path string, editFn func([]byte) ([]byte, error)) error { + lock := getPathLock(path) + lock.Lock() + defer lock.Unlock() + return c.baseFS.EditFile(path, editFn) +} + +func (c *ConcurrentFS) ReadDir(path string) ([]os.DirEntry, error) { + return c.baseFS.ReadDir(path) +} + +func (c *ConcurrentFS) Open(path string) (fs.File, error) { + lock := getPathLock(path) + lock.RLock() + defer lock.RUnlock() + return c.baseFS.Open(path) +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index ed373a28f..39b01827a 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -150,6 +150,17 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { return entry.Tool, true } +// ListTools returns a slice of all registered tool names. 
+func (r *ToolRegistry) ListTools() []string { + r.mu.RLock() + defer r.mu.RUnlock() + names := make([]string, 0, len(r.tools)) + for name := range r.tools { + names = append(names, name) + } + return names +} + func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult { return r.ExecuteWithContext(ctx, name, args, "", "", nil) } diff --git a/pkg/tools/spawn_status_test.go b/pkg/tools/spawn_status_test.go index 9c772d61a..6b734191c 100644 --- a/pkg/tools/spawn_status_test.go +++ b/pkg/tools/spawn_status_test.go @@ -6,12 +6,14 @@ import ( "strings" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/config" ) func TestSpawnStatusTool_Name(t *testing.T) { provider := &MockLLMProvider{} workspace := t.TempDir() - manager := NewSubagentManager(provider, "test-model", workspace) + manager := NewSubagentManager(provider, "test-model", nil, workspace, config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) if tool.Name() != "spawn_status" { @@ -22,7 +24,7 @@ func TestSpawnStatusTool_Name(t *testing.T) { func TestSpawnStatusTool_Description(t *testing.T) { provider := &MockLLMProvider{} workspace := t.TempDir() - manager := NewSubagentManager(provider, "test-model", workspace) + manager := NewSubagentManager(provider, "test-model", nil, workspace, config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) desc := tool.Description() @@ -37,7 +39,7 @@ func TestSpawnStatusTool_Description(t *testing.T) { func TestSpawnStatusTool_Parameters(t *testing.T) { provider := &MockLLMProvider{} workspace := t.TempDir() - manager := NewSubagentManager(provider, "test-model", workspace) + manager := NewSubagentManager(provider, "test-model", nil, workspace, config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) params := tool.Parameters() @@ -64,7 +66,7 @@ func TestSpawnStatusTool_NilManager(t *testing.T) { func TestSpawnStatusTool_Empty(t *testing.T) { provider := &MockLLMProvider{} workspace := 
t.TempDir() - manager := NewSubagentManager(provider, "test-model", workspace) + manager := NewSubagentManager(provider, "test-model", nil, workspace, config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) result := tool.Execute(context.Background(), map[string]any{}) @@ -79,7 +81,7 @@ func TestSpawnStatusTool_Empty(t *testing.T) { func TestSpawnStatusTool_ListAll(t *testing.T) { provider := &MockLLMProvider{} workspace := t.TempDir() - manager := NewSubagentManager(provider, "test-model", workspace) + manager := NewSubagentManager(provider, "test-model", nil, workspace, config.TeamToolsConfig{}, nil) now := time.Now().UnixMilli() manager.mu.Lock() @@ -140,7 +142,7 @@ func TestSpawnStatusTool_ListAll(t *testing.T) { func TestSpawnStatusTool_GetByID(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) manager.mu.Lock() manager.tasks["subagent-42"] = &SubagentTask{ @@ -175,7 +177,7 @@ func TestSpawnStatusTool_GetByID(t *testing.T) { func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"}) @@ -189,7 +191,7 @@ func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) { func TestSpawnStatusTool_TaskID_NonString(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSpawnStatusTool(manager) for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} { @@ -205,7 +207,7 @@ 
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) { func TestSpawnStatusTool_ResultTruncation(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) longResult := strings.Repeat("X", 500) manager.mu.Lock() @@ -234,7 +236,7 @@ func TestSpawnStatusTool_ResultTruncation(t *testing.T) { func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) // Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit. cjkChar := string(rune(0x5b57)) @@ -265,7 +267,7 @@ func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) { func TestSpawnStatusTool_StatusCounts(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) manager.mu.Lock() for i, status := range []string{"running", "running", "completed", "failed", "canceled"} { @@ -290,7 +292,7 @@ func TestSpawnStatusTool_StatusCounts(t *testing.T) { func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) now := time.Now().UnixMilli() manager.mu.Lock() @@ -325,7 +327,7 @@ func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) { func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, 
"/tmp/test", config.TeamToolsConfig{}, nil) manager.mu.Lock() manager.tasks["subagent-1"] = &SubagentTask{ @@ -357,7 +359,7 @@ func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) { func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) manager.mu.Lock() manager.tasks["subagent-99"] = &SubagentTask{ @@ -379,7 +381,7 @@ func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) { func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) manager.mu.Lock() manager.tasks["subagent-1"] = &SubagentTask{ diff --git a/pkg/tools/spawn_test.go b/pkg/tools/spawn_test.go index fda6bbd89..70e7b613e 100644 --- a/pkg/tools/spawn_test.go +++ b/pkg/tools/spawn_test.go @@ -4,6 +4,8 @@ import ( "context" "strings" "testing" + + "github.com/sipeed/picoclaw/pkg/config" ) // mockSpawner implements SubTurnSpawner for testing @@ -26,7 +28,7 @@ func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*Too func TestSpawnTool_Execute_EmptyTask(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSpawnTool(manager) ctx := context.Background() @@ -60,7 +62,7 @@ func TestSpawnTool_Execute_EmptyTask(t *testing.T) { func TestSpawnTool_Execute_ValidTask(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", 
config.TeamToolsConfig{}, nil) tool := NewSpawnTool(manager) tool.SetSpawner(&mockSpawner{}) diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 9a1a8b802..4af3eb26a 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -3,10 +3,13 @@ package tools import ( "context" "fmt" + "strings" "sync" "sync/atomic" "time" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -19,7 +22,9 @@ type SubTurnSpawner interface { // SubTurnConfig holds configuration for spawning a sub-turn. type SubTurnConfig struct { Model string + Provider providers.LLMProvider // non-nil overrides the child agent's provider Tools []Tool + EmptyTools bool // true: child agent gets an empty ToolRegistry (overrides Tools) SystemPrompt string MaxTokens int Temperature float64 @@ -32,6 +37,30 @@ type SubTurnConfig struct { InitialTokenBudget *atomic.Int64 // Shared token budget for team members; nil if no budget } +// ModelTag constants define the recognized capability labels for models in config.json. +// These are set via `"tags": ["vision", "code"]` under each model in the model list. +const ( + ModelTagVision = "vision" // Supports image/screenshot input (multimodal) + ModelTagImageGen = "image-gen" // Supports image generation output (e.g. DALL-E, Stable Diffusion) + ModelTagCode = "code" // Specialized for code generation and analysis + ModelTagFast = "fast" // Low-latency model, suited for lightweight tasks + ModelTagLongContext = "long-context" // Supports very long context windows (>100k tokens) + ModelTagReasoning = "reasoning" // Strong logical/math reasoning (e.g., o1, deepseek-r1) +) + +// modelTagDescriptions provides LLM-readable explanations of each known tag, +// injected at runtime into the tool description to guide model selection. 
+// +//nolint:unused // Reserved for future use in dynamic tool descriptions +var modelTagDescriptions = map[string]string{ + ModelTagVision: "can analyze images and screenshots (multimodal input)", + ModelTagImageGen: "can generate images from text descriptions (e.g. DALL-E, Stable Diffusion)", + ModelTagCode: "specialized in code generation and debugging", + ModelTagFast: "fast and lightweight, ideal for simple or high-frequency tasks", + ModelTagLongContext: "handles very long inputs (>100k tokens)", + ModelTagReasoning: "excels at logical reasoning, math, and multi-step planning", +} + type SubagentTask struct { ID string Task string @@ -58,8 +87,11 @@ type SubagentManager struct { mu sync.RWMutex provider providers.LLMProvider defaultModel string + allowedModels []providers.FallbackCandidate + bus *bus.MessageBus workspace string tools *ToolRegistry + teamConfig config.TeamToolsConfig maxIterations int maxTokens int temperature float64 @@ -71,12 +103,19 @@ type SubagentManager struct { func NewSubagentManager( provider providers.LLMProvider, - defaultModel, workspace string, + defaultModel string, + candidates []providers.FallbackCandidate, + workspace string, + teamConfig config.TeamToolsConfig, + bus *bus.MessageBus, ) *SubagentManager { return &SubagentManager{ tasks: make(map[string]*SubagentTask), provider: provider, defaultModel: defaultModel, + allowedModels: candidates, + teamConfig: teamConfig, + bus: bus, workspace: workspace, tools: NewToolRegistry(), maxIterations: 10, @@ -84,6 +123,65 @@ func NewSubagentManager( } } +// IsModelAllowed checks if a specific requested model exists in the permitted candidates list. +func (sm *SubagentManager) IsModelAllowed(model string) bool { + // If the user requested the default model directly, that's automatically allowed + if model == sm.defaultModel { + return true + } + + // 1. 
Check against explicitly allowed models in team config + for _, cand := range sm.teamConfig.AllowedModels { + if cand.Name == model { + return true + } + } + + // 2. Otherwise, check against the resolved candidates (primary + fallbacks + explicitly configured) + // If teamConfig.AllowedModels is set, we strictly enforce it and DO NOT fall back to candidates + // unless the candidate model has tags that overlap with AllowedTags. But since AllowedTags + // was not implemented yet, just check fallback for backwards compatibility if teamConfig is empty. + if len(sm.teamConfig.AllowedModels) > 0 { + return false + } + + for _, cand := range sm.allowedModels { + if cand.Model == model { + return true + } + } + return false +} + +// ModelCapabilityHint generates a human-readable summary of allowed models and their tags. +// This is injected into the coordinator's tool descriptions so the LLM can make better routing decisions. +func (sm *SubagentManager) ModelCapabilityHint() string { + if len(sm.allowedModels) == 0 { + return "" + } + + var modelLines []string + for _, cand := range sm.allowedModels { + modelLines = append(modelLines, fmt.Sprintf(" - %s (general purpose)", cand.Model)) + } + + hint := "When selecting a 'model' for sub-agents, use ONLY these configured models:\n" + if len(sm.teamConfig.AllowedModels) > 0 { + for _, cand := range sm.teamConfig.AllowedModels { + tagsStr := "" + if len(cand.Tags) > 0 { + tagsStr = fmt.Sprintf(" [%s]", strings.Join(cand.Tags, ", ")) + } + hint += fmt.Sprintf(" - %s%s\n", cand.Name, tagsStr) + } + } else { + hint += strings.Join(modelLines, "\n") + } + + hint += "\nIf a task requires vision/image analysis, you MUST select a model with the 'vision' tag. If no suitable model is available, omit the 'model' field to use the default." 
+ return hint +} + func (sm *SubagentManager) SetSpawner(spawner SpawnSubTurnFunc) { sm.mu.Lock() defer sm.mu.Unlock() @@ -154,9 +252,6 @@ func (sm *SubagentManager) runTask( ) { task.Status = "running" task.Created = time.Now().UnixMilli() - // TODO(eventbus): once subagents are modeled as child turns inside - // pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent - // AgentLoop instead of this legacy manager. // Check if context is already canceled before starting select { @@ -244,7 +339,6 @@ After completing the task, provide a clear summary of what was done.` sm.mu.Lock() defer func() { sm.mu.Unlock() - // Call callback if provided and result is set if callback != nil && result != nil { callback(ctx, result) } @@ -253,7 +347,6 @@ After completing the task, provide a clear summary of what was done.` if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) - // Check if it was canceled if ctx.Err() != nil { task.Status = "canceled" task.Result = "Task canceled during execution" @@ -315,13 +408,40 @@ func (sm *SubagentManager) ListTaskCopies() []SubagentTask { return copies } +// BuildBaseWorkerConfig returns a base ToolLoopConfig that can be customized for isolated workers. +func (sm *SubagentManager) BuildBaseWorkerConfig(ctx context.Context) ToolLoopConfig { + sm.mu.RLock() + defer sm.mu.RUnlock() + + var llmOptions map[string]any + if sm.hasMaxTokens || sm.hasTemperature { + llmOptions = map[string]any{} + if sm.hasMaxTokens { + llmOptions["max_tokens"] = sm.maxTokens + } + if sm.hasTemperature { + llmOptions["temperature"] = sm.temperature + } + } + + return ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: sm.tools, + MaxIterations: sm.maxIterations, + LLMOptions: llmOptions, + } +} + // SubagentTool executes a subagent task synchronously and returns the result. // It directly calls SubTurnSpawner with Async=false for synchronous execution. 
type SubagentTool struct { - spawner SubTurnSpawner - defaultModel string - maxTokens int - temperature float64 + spawner SubTurnSpawner + defaultModel string + maxTokens int + temperature float64 + isModelAllowed func(string) bool // nil means no allowlist check + modelHint func() string // nil means no model hint in description } func NewSubagentTool(manager *SubagentManager) *SubagentTool { @@ -329,9 +449,11 @@ func NewSubagentTool(manager *SubagentManager) *SubagentTool { return &SubagentTool{} } return &SubagentTool{ - defaultModel: manager.defaultModel, - maxTokens: manager.maxTokens, - temperature: manager.temperature, + defaultModel: manager.defaultModel, + maxTokens: manager.maxTokens, + temperature: manager.temperature, + isModelAllowed: manager.IsModelAllowed, + modelHint: manager.ModelCapabilityHint, } } @@ -345,7 +467,13 @@ func (t *SubagentTool) Name() string { } func (t *SubagentTool) Description() string { - return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM." + base := "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance with an optional role (system prompt) and model. Returns execution summary to user and full details to LLM." + if t.modelHint != nil { + if hint := t.modelHint(); hint != "" { + return base + "\n\n" + hint + } + } + return base } func (t *SubagentTool) Parameters() map[string]any { @@ -360,6 +488,14 @@ func (t *SubagentTool) Parameters() map[string]any { "type": "string", "description": "Optional short label for the task (for display)", }, + "role": map[string]any{ + "type": "string", + "description": "Optional system prompt / role assignment for the subagent (e.g. 'You are an expert code reviewer'). 
If omitted, a default subagent prompt is used.", + }, + "model": map[string]any{ + "type": "string", + "description": "Optional specific LLM model ID to route this task to. If omitted, inherits the parent's model.", + }, }, "required": []string{"task"}, } @@ -372,15 +508,35 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe } label, _ := args["label"].(string) + role, _ := args["role"].(string) + modelParam, _ := args["model"].(string) + modelParam = strings.TrimSpace(modelParam) + + // Validate model against allowlist if provided + if modelParam != "" && t.isModelAllowed != nil && !t.isModelAllowed(modelParam) { + return ErrorResult(fmt.Sprintf("requested model '%s' is not in the allowed models list", modelParam)). + WithError(fmt.Errorf("model %s not allowed", modelParam)) + } + + // Determine the model to use + targetModel := t.defaultModel + if modelParam != "" { + targetModel = modelParam + } - // Build system prompt for subagent + // Build ActualSystemPrompt: prefer explicit role, fall back to auto-generated prompt + var actualSystemPrompt string + if role != "" { + actualSystemPrompt = role + } + + // Build SystemPrompt (task description, becomes first user message in sub-turn) systemPrompt := fmt.Sprintf( `You are a subagent. Complete the given task independently and provide a clear, concise result. Task: %s`, task, ) - if label != "" { systemPrompt = fmt.Sprintf( `You are a subagent labeled "%s". Complete the given task independently and provide a clear, concise result. 
@@ -394,12 +550,13 @@ Task: %s`, // Use spawner if available (direct SpawnSubTurn call) if t.spawner != nil { result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{ - Model: t.defaultModel, - Tools: nil, // Will inherit from parent via context - SystemPrompt: systemPrompt, - MaxTokens: t.maxTokens, - Temperature: t.temperature, - Async: false, // Synchronous execution + Model: targetModel, + Tools: nil, // Will inherit from parent via context + SystemPrompt: systemPrompt, + ActualSystemPrompt: actualSystemPrompt, + MaxTokens: t.maxTokens, + Temperature: t.temperature, + Async: false, // Synchronous execution }) if err != nil { return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 89ac7d4b5..68ced6f0b 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -46,7 +47,7 @@ func (m *MockLLMProvider) GetContextWindow() int { func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) manager.SetLLMOptions(2048, 0.6) // Verify options are set on manager @@ -67,7 +68,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { // TestSubagentTool_Name verifies tool name func TestSubagentTool_Name(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) if tool.Name() != "subagent" { @@ -78,7 +79,7 @@ func TestSubagentTool_Name(t 
*testing.T) { // TestSubagentTool_Description verifies tool description func TestSubagentTool_Description(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) desc := tool.Description() @@ -93,7 +94,7 @@ func TestSubagentTool_Description(t *testing.T) { // TestSubagentTool_Parameters verifies tool parameters schema func TestSubagentTool_Parameters(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) params := tool.Parameters() @@ -143,7 +144,7 @@ func TestSubagentTool_Parameters(t *testing.T) { // TestSubagentTool_Execute_Success tests successful execution func TestSubagentTool_Execute_Success(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) tool.SetSpawner(&mockSpawner{}) @@ -198,7 +199,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) { // TestSubagentTool_Execute_NoLabel tests execution without label func TestSubagentTool_Execute_NoLabel(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) tool.SetSpawner(&mockSpawner{}) @@ -222,7 +223,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) { // TestSubagentTool_Execute_MissingTask tests error handling for missing task func TestSubagentTool_Execute_MissingTask(t *testing.T) { provider := &MockLLMProvider{} - 
manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) ctx := context.Background() @@ -272,7 +273,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) { // TestSubagentTool_Execute_ContextPassing verifies context is properly used func TestSubagentTool_Execute_ContextPassing(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) tool.SetSpawner(&mockSpawner{}) @@ -298,7 +299,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) { func TestSubagentTool_ForUserTruncation(t *testing.T) { // Create a mock provider that returns very long content provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test") + manager := NewSubagentManager(provider, "test-model", nil, "/tmp/test", config.TeamToolsConfig{}, nil) tool := NewSubagentTool(manager) tool.SetSpawner(&mockSpawner{}) diff --git a/pkg/tools/team.go b/pkg/tools/team.go new file mode 100644 index 000000000..d2fb49917 --- /dev/null +++ b/pkg/tools/team.go @@ -0,0 +1,1156 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type TeamTool struct { + manager *SubagentManager + spawner SubTurnSpawner + cfg *config.Config + originChannel string + originChatID string +} + +type TeamMember struct { + ID string + Role string + Task string + Model string // Heterogeneous Agents: Optional specific model for this task + DependsOn []string // List of member IDs this member depends on + Produces string // 
Auto-reviewer: declares artifact type ("code", "data", "document") +} + +func NewTeamTool(manager *SubagentManager, cfg *config.Config) *TeamTool { + return &TeamTool{ + manager: manager, + cfg: cfg, + originChannel: "cli", + originChatID: "direct", + } +} + +// SetSpawner sets the SubTurnSpawner used to execute team members as sub-turns. +func (t *TeamTool) SetSpawner(spawner SubTurnSpawner) { + t.spawner = spawner +} + +func (t *TeamTool) Name() string { + return "team" +} + +func (t *TeamTool) Description() string { + base := `Compose and execute a team of specialized sub-agents to accomplish a complex task. + +WHEN TO USE THIS TOOL (use proactively — do not attempt to handle these alone): +- The task involves 2 or more distinct areas of concern (e.g. research + writing, coding + testing, data gathering + analysis). +- The task would require more than 5 consecutive tool calls if done alone. +- Any part of the task can be done in parallel to save time. +- The task is large enough that a single agent would likely lose context or quality midway. +- The user asks you to "build", "create", "generate", "analyze", or "convert" something non-trivial. +When in doubt, prefer delegation over doing everything yourself. + +CRITICAL RULES FOR TASK PLANNING: +1. Think like a project manager: analyze the full task first, then design the team structure before spawning anyone. +2. Decompose the task into the smallest independently-ownable units of work. A member should own exactly ONE distinct concern — not a broad compound goal. +3. Identify dependencies between units: if one member's output is required by another, declare it via 'depends_on'. Independent units should run concurrently. +4. Each member's 'task' must be precise and self-contained. Include relevant context (e.g. reference to outputs from dependencies) directly in the task description. +5. Sub-agents are full agents with access to the same tools, including this 'team' tool. 
If a member's sub-task is itself complex, it may recursively form its own team. + +Strategy guide: +- sequential: each step depends on the full output of the previous step in a strict chain. +- parallel: all tasks are fully independent with no shared inputs or outputs. +- dag: most real-world tasks — some tasks depend on others, some can run concurrently. +- evaluator_optimizer: the output needs iterative critique and revision cycles.` + + if t.manager != nil { + if hint := t.manager.ModelCapabilityHint(); hint != "" { + return base + "\n\n" + hint + } + } + return base +} + +func (t *TeamTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "strategy": map[string]any{ + "type": "string", + "enum": []string{"sequential", "parallel", "dag", "evaluator_optimizer"}, + "description": "How to run the team members. 'sequential': one after another. 'parallel': all at once. 'dag': execute based on declared dependencies. 'evaluator_optimizer': EXACTLY two members (worker & evaluator). The evaluator will check the worker's output; if it fails, the worker is revived with its FULL stateful memory intact and asked to fix it. Use this for complex generation tasks (like coding) requiring deep reasoning.", + }, + "max_team_tokens": map[string]any{ + "type": "integer", + "description": "The maximum combined LLM tokens (prompt + completion) this entire team is allowed to consume. 
Once exceeded, the team is instantly killed.", + }, + "members": map[string]any{ + "type": "array", + "description": "The list of sub-agents in the team.", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + "description": "Unique identifier for this member, used for dependencies in 'dag' strategy.", + }, + "role": map[string]any{ + "type": "string", + "description": "The system prompt/role assignment for the member.", + }, + "task": map[string]any{ + "type": "string", + "description": "The specific task this member needs to accomplish.", + }, + "model": map[string]any{ + "type": "string", + "description": "Optional specific LLM model ID to route this task to (e.g., 'gpt-4o' for vision, 'claude-3-5-sonnet' for logic). If omitted, inherits the parent's model.", + }, + "depends_on": map[string]any{ + "type": "array", + "description": "List of 'id' strings this member depends on. Only applicable for 'dag' strategy.", + "items": map[string]any{"type": "string"}, + }, + "produces": map[string]any{ + "type": "string", + "description": "Declares the type of artifact this member produces. Use 'code' for source code files, 'data' for structured data/JSON/CSV, 'document' for prose documents/reports. When set, the framework automatically appends a QA reviewer step after all workers finish to validate output correctness. Omit if no verification is needed.", + }, + }, + "required": []string{"role", "task"}, + }, + }, + }, + "required": []string{"strategy", "members"}, + } +} + +func (t *TeamTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +// reviewerTaskTemplates maps a `produces` artifact type to the task prompt +// that the auto-injected QA reviewer will receive. +var reviewerTaskTemplates = map[string]string{ + "code": "You are a code quality reviewer. Read all code files in the workspace that were just written by your predecessors. 
Check for: syntax errors, incorrect or missing imports, broken logic, type mismatches, and any issues that would cause compilation or runtime failures. List every issue found with the filename and line number if possible. If everything looks correct, respond with 'REVIEW PASSED'.", + "data": "You are a data validation reviewer. Read all output data files (JSON, CSV, YAML, etc.) in the workspace. Check for: invalid format, missing required fields, schema inconsistencies, and malformed values. List every issue found. If everything is valid, respond with 'REVIEW PASSED'.", + "document": "You are a document quality reviewer. Read all output documents in the workspace. Check for: logical inconsistencies, incomplete sections, factual contradictions, and poor structure. List every issue found. If the documents are complete and correct, respond with 'REVIEW PASSED'.", +} + +// maybeRunAutoReviewer inspects TeamMembers for `produces` declarations. +// If any member produced a verifiable artifact type, it runs an automatic +// QA reviewer agent after all workers have completed. 
+func (t *TeamTool) maybeRunAutoReviewer( + ctx context.Context, + members []TeamMember, + baseConfig ToolLoopConfig, + workerSummary string, +) string { + // Collect unique produces types from all members + producedTypes := make(map[string]bool) + for _, m := range members { + if m.Produces != "" { + producedTypes[m.Produces] = true + } + } + if len(producedTypes) == 0 { + return "" // No verifiable artifacts declared, skip review + } + + // Build reviewer task: combine templates for all declared artifact types + var taskParts []string + for artifactType := range producedTypes { + if tmpl, ok := reviewerTaskTemplates[artifactType]; ok { + taskParts = append(taskParts, tmpl) + } + } + if len(taskParts) == 0 { + return "" // Unknown produces types, skip + } + + sm := t.manager + sm.mu.RLock() + teamConfig := sm.teamConfig + sm.mu.RUnlock() + + if teamConfig.DisableAutoReviewer { + return "" + } + + reviewerTask := strings.Join(taskParts, "\n\n") + + "\n\nContext from the workers that produced these artifacts:\n" + workerSummary + + reviewerMessages := []providers.Message{ + {Role: "user", Content: reviewerTask}, + } + + // Use a dedicated reviewer model if configured — typically a cheaper/faster model + // is sufficient for QA review, saving tokens compared to the main worker model. 
+ reviewerConfig := baseConfig + if teamConfig.ReviewerModel != "" && sm.IsModelAllowed(teamConfig.ReviewerModel) { + reviewerConfig.Model = teamConfig.ReviewerModel + } + + cnf, err := t.cfg.GetModelConfig(reviewerConfig.Model) + + if err == nil { + var provider providers.LLMProvider + var model string + provider, model, err = providers.CreateProviderFromConfig(cnf) + + if err == nil { + reviewerConfig.Model = model + reviewerConfig.Provider = provider + } + } + + providerName := "unknown" + if reviewerConfig.Provider != nil { + providerName = reviewerConfig.Provider.GetDefaultModel() + } + + logger.InfoCF( + "team", + fmt.Sprintf("reviewer use provider: [%s] and model: [%s]", providerName, reviewerConfig.Model), + map[string]any{ + "model": teamConfig.ReviewerModel, + }, + ) + + loopContent, _, err := t.spawnWorker(ctx, reviewerConfig, reviewerMessages, nil) + if err != nil { + return fmt.Sprintf("[Auto-Reviewer] Failed to run: %v", err) + } + return "[Auto-Reviewer Result]\n" + loopContent +} + +func (t *TeamTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + strategy, ok := args["strategy"].(string) + if !ok { + return ErrorResult("strategy is required") + } + + if t.manager == nil { + return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) + } + + sm := t.manager + sm.mu.RLock() + teamConfig := sm.teamConfig + sm.mu.RUnlock() + + // 1. 
Validate Strategy + validStrategy := false + if len(teamConfig.AllowedStrategies) > 0 { + for _, s := range teamConfig.AllowedStrategies { + if strategy == s { + validStrategy = true + break + } + } + } else { + // Default allowed strategies if not configured + if strategy == "sequential" || strategy == "parallel" || strategy == "dag" || + strategy == "evaluator_optimizer" { + validStrategy = true + } + } + + if !validStrategy { + return ErrorResult(fmt.Sprintf("strategy '%s' is not allowed by configuration", strategy)) + } + + membersRaw, ok := args["members"].([]any) + if !ok || len(membersRaw) == 0 { + return ErrorResult("members map array is required and must not be empty") + } + + // 2. Validate Max Members + if teamConfig.MaxMembers > 0 && len(membersRaw) > teamConfig.MaxMembers { + return ErrorResult( + fmt.Sprintf( + "Team exceeds maximum allowed members (%d). You requested %d members.", + teamConfig.MaxMembers, + len(membersRaw), + ), + ) + } + + maxTokensFloat, ok := args["max_team_tokens"].(float64) + + // Enforce hard budget from config as the ceiling. + effectiveMaxTokens := int64(0) + if teamConfig.MaxTeamTokens > 0 { + effectiveMaxTokens = int64(teamConfig.MaxTeamTokens) + } + + if ok && maxTokensFloat > 0 { + requestedTokens := int64(maxTokensFloat) + if effectiveMaxTokens > 0 && requestedTokens > effectiveMaxTokens { + // LLM requested more than config allows: clamp to the hard ceiling. + // effectiveMaxTokens already holds the correct ceiling, no change needed. + } else if effectiveMaxTokens == 0 || requestedTokens < effectiveMaxTokens { + // LLM asked for less, or there is no hard limit: honor the requested budget. 
+ effectiveMaxTokens = requestedTokens + } + } + + var budget *atomic.Int64 + if effectiveMaxTokens > 0 { + budget = &atomic.Int64{} + budget.Store(effectiveMaxTokens) + } + + var members []TeamMember + for i, mRaw := range membersRaw { + mMap, ok := mRaw.(map[string]any) + if !ok { + return ErrorResult(fmt.Sprintf("member at index %d is invalid", i)) + } + + id, iOk := mMap["id"].(string) + role, rOk := mMap["role"].(string) + task, tOk := mMap["task"].(string) + + if !rOk || !tOk || strings.TrimSpace(role) == "" || strings.TrimSpace(task) == "" { + return ErrorResult(fmt.Sprintf("member at index %d is missing required 'role' or 'task'", i)) + } + + // ID is highly recommended, generate one if missing for backwards compatibility + if !iOk || strings.TrimSpace(id) == "" { + id = fmt.Sprintf("member_%d", i) + } + + modelStr, _ := mMap["model"].(string) + modelStr = strings.TrimSpace(modelStr) + + var dependsOn []string + if depRaw, dOk := mMap["depends_on"].([]any); dOk { + for _, d := range depRaw { + if dStr, dsOk := d.(string); dsOk { + dependsOn = append(dependsOn, dStr) + } + } + } + + producesStr, _ := mMap["produces"].(string) + producesStr = strings.TrimSpace(producesStr) + + members = append(members, TeamMember{ + ID: id, + Role: role, + Task: task, + Model: modelStr, + DependsOn: dependsOn, + Produces: producesStr, + }) + } + + // Base struct setup + baseConfig := t.manager.BuildBaseWorkerConfig(ctx) + if budget != nil { + baseConfig.RemainingTokenBudget = budget + } + + // Create a cancellable context for team bounding. + // cancel() is always deferred so any spawned goroutines are cleaned up on return. 
+ timeoutDur := 15 * time.Minute + if teamConfig.MaxTimeoutMinutes > 0 { + timeoutDur = time.Duration(teamConfig.MaxTimeoutMinutes) * time.Minute + } + teamCtx, cancel := context.WithTimeout(ctx, timeoutDur) + defer cancel() + + // If strategy is parallel or dag, we must upgrade the file tools to be concurrent-safe (locking) + if strategy == "parallel" || strategy == "dag" { + baseConfig.Tools = upgradeRegistryForConcurrency(baseConfig.Tools) + } + + // Resolve max context runes for dependency injection into downstream prompts. + contextLimit := 8000 // default + if teamConfig.MaxContextRunes > 0 { + contextLimit = teamConfig.MaxContextRunes + } + + switch strategy { + case "sequential": + result := t.executeSequential(teamCtx, baseConfig, members, contextLimit) + if reviewNote := t.maybeRunAutoReviewer(teamCtx, members, baseConfig, result.ForLLM); reviewNote != "" { + result.ForLLM += "\n\n" + reviewNote + result.ForUser += "\n\n" + reviewNote + } + return result + case "dag": + result := t.executeDAG(teamCtx, cancel, baseConfig, members, contextLimit) + if reviewNote := t.maybeRunAutoReviewer(teamCtx, members, baseConfig, result.ForLLM); reviewNote != "" { + result.ForLLM += "\n\n" + reviewNote + result.ForUser += "\n\n" + reviewNote + } + return result + case "evaluator_optimizer": + return t.executeEvaluatorOptimizer(teamCtx, baseConfig, members, contextLimit) + } + // parallel + result := t.executeParallel(teamCtx, baseConfig, members) + if reviewNote := t.maybeRunAutoReviewer(teamCtx, members, baseConfig, result.ForLLM); reviewNote != "" { + result.ForLLM += "\n\n" + reviewNote + result.ForUser += "\n\n" + reviewNote + } + return result +} + +// upgradeRegistryForConcurrency takes an existing ToolRegistry, clones it, +// and upgrades any tools that implement ConcurrencyUpgradeable to their locking counterparts. 
+func upgradeRegistryForConcurrency(original *ToolRegistry) *ToolRegistry { + if original == nil { + return nil + } + + upgraded := NewToolRegistry() + for _, name := range original.ListTools() { + tool, ok := original.Get(name) + if !ok { + continue + } + + if upgradeable, isUpgradeable := tool.(ConcurrencyUpgradeable); isUpgradeable { + upgraded.Register(upgradeable.UpgradeToConcurrent()) + } else { + upgraded.Register(tool) + } + } + return upgraded +} + +// spawnWorker executes a single team member's turn, routing through SubTurnSpawner when available. +// Returns (content, messages, error). The messages slice is non-nil only for stateful workers +// (evaluator_optimizer) and can be passed as InitialMessages for the next iteration. +func (t *TeamTool) spawnWorker( + ctx context.Context, + cfg ToolLoopConfig, + messages []providers.Message, + budget *atomic.Int64, +) (string, []providers.Message, error) { + if t.spawner == nil { + // Fallback: direct RunToolLoop (no turnState integration) + res, err := RunToolLoop(ctx, cfg, messages, t.originChannel, t.originChatID) + if err != nil { + return "", nil, err + } + return res.Content, res.Messages, nil + } + + // Convert ToolLoopConfig + messages into SubTurnConfig for SubTurnSpawner. 
+ var toolSlice []Tool + if cfg.Tools != nil { + for _, name := range cfg.Tools.ListTools() { + if tool, ok := cfg.Tools.Get(name); ok { + toolSlice = append(toolSlice, tool) + } + } + } + + // Extract system prompt and non-system messages + var actualSystemPrompt string + var initialMessages []providers.Message + for _, msg := range messages { + if msg.Role == "system" { + actualSystemPrompt = msg.Content + } else { + initialMessages = append(initialMessages, msg) + } + } + + maxTokens, temperature := getLLMOptionsFromConfig(cfg) + + subCfg := SubTurnConfig{ + Model: cfg.Model, + Provider: cfg.Provider, + Tools: toolSlice, + ActualSystemPrompt: actualSystemPrompt, + InitialMessages: initialMessages, + MaxTokens: maxTokens, + Temperature: temperature, + Async: false, + InitialTokenBudget: budget, + } + + res, err := t.spawner.SpawnSubTurn(ctx, subCfg) + if err != nil { + return "", nil, err + } + return res.ForLLM, res.Messages, nil +} + +// spawnWorkerEmptyTools is like spawnWorker but forces an empty tool registry on the sub-turn. +// Used for the evaluator in evaluator_optimizer to prevent side effects. 
+func (t *TeamTool) spawnWorkerEmptyTools( + ctx context.Context, + cfg ToolLoopConfig, + messages []providers.Message, +) (string, error) { + if t.spawner == nil { + // Fallback: direct RunToolLoop with empty registry + emptyConfig := cfg + emptyConfig.Tools = NewToolRegistry() + res, err := RunToolLoop(ctx, emptyConfig, messages, t.originChannel, t.originChatID) + if err != nil { + return "", err + } + return res.Content, nil + } + + var actualSystemPrompt string + var initialMessages []providers.Message + for _, msg := range messages { + if msg.Role == "system" { + actualSystemPrompt = msg.Content + } else { + initialMessages = append(initialMessages, msg) + } + } + + maxTokens, temperature := getLLMOptionsFromConfig(cfg) + + subCfg := SubTurnConfig{ + Model: cfg.Model, + Provider: cfg.Provider, + EmptyTools: true, + ActualSystemPrompt: actualSystemPrompt, + InitialMessages: initialMessages, + MaxTokens: maxTokens, + Temperature: temperature, + Async: false, + } + + res, err := t.spawner.SpawnSubTurn(ctx, subCfg) + if err != nil { + return "", err + } + return res.ForLLM, nil +} + +// getLLMOptionsFromConfig extracts MaxTokens and Temperature from a ToolLoopConfig's LLMOptions map. +func getLLMOptionsFromConfig(cfg ToolLoopConfig) (int, float64) { + var maxTokens int + var temperature float64 + if cfg.LLMOptions != nil { + if v, ok := cfg.LLMOptions["max_tokens"].(int); ok { + maxTokens = v + } + if v, ok := cfg.LLMOptions["temperature"].(float64); ok { + temperature = v + } + } + return maxTokens, temperature +} + +// potentially overriding the model based on the member's definition. 
func (t *TeamTool) buildWorkerConfig(
	baseConfig ToolLoopConfig,
	registry *ToolRegistry,
	m TeamMember,
) (ToolLoopConfig, error) {
	cfg := baseConfig
	cfg.Tools = registry
	// Heterogeneous Agents: Override model if this team member requested a specific one
	if m.Model != "" {
		// Guard against members smuggling in models outside the allow-list.
		if !t.manager.IsModelAllowed(m.Model) {
			return cfg, fmt.Errorf(
				"requested model '%s' is not in the allowed fallback candidates list for this agent workspace",
				m.Model,
			)
		}
		// Alias resolution (model_name -> provider model id) is handled by
		// GetModelConfig; the old manual scan over t.cfg.ModelList was removed.
		cnf, err := t.cfg.GetModelConfig(m.Model)
		if err != nil {
			return cfg, err
		}

		provider, model, err := providers.CreateProviderFromConfig(cnf)
		if err != nil {
			// NOTE(review): this path returns ToolLoopConfig{} while the two
			// error paths above return cfg — inconsistent; callers should not
			// rely on the config value when err != nil.
			return ToolLoopConfig{}, err
		}

		cfg.Model = model
		cfg.Provider = provider
	}

	// NOTE(review): "providerName" actually holds GetDefaultModel()'s result,
	// i.e. a model name, not a provider name — the log label is misleading.
	providerName := "unknown"
	if cfg.Provider != nil {
		providerName = cfg.Provider.GetDefaultModel()
	}

	logger.InfoCF(
		"team",
		fmt.Sprintf("[%s] use provider: [%s] and model: [%s]", m.Role, providerName, cfg.Model),
		map[string]any{
			"member_index": m.ID,
			"model":        m.Model,
		},
	)

	return cfg, nil
}

// executeSequential runs members one after another, feeding each member the
// (truncated) output of its predecessor as extra context. Fails fast on the
// first configuration or worker error.
func (t *TeamTool) executeSequential(
	ctx context.Context,
	baseConfig ToolLoopConfig,
	members []TeamMember,
	contextLimit int,
) *ToolResult {
	var finalOutput strings.Builder
	finalOutput.WriteString("Team Execution Summary (Sequential):\n\n")

	var previousResult string

	for i, m := range members {
		// If there is a previous result, we append it to the task so the new agent sees it.
		actualTask := m.Task
		if i > 0 && previousResult != "" {
			actualTask = fmt.Sprintf(
				"%s\n\n--- Context from previous phase ---\n%s",
				m.Task,
				truncateContextN(previousResult, contextLimit),
			)
		}

		messages := []providers.Message{
			{Role: "system", Content: m.Role},
			{Role: "user", Content: actualTask},
		}

		workerConfig, err := t.buildWorkerConfig(baseConfig, baseConfig.Tools, m)
		if err != nil {
			errStr := fmt.Sprintf("Phase %d (Role: %s) configuration failed: %v", i+1, m.Role, err)
			finalOutput.WriteString(errStr + "\n")
			return ErrorResult(errStr).WithError(err)
		}

		content, _, err := t.spawnWorker(ctx, workerConfig, messages, baseConfig.RemainingTokenBudget)
		if err != nil {
			errStr := fmt.Sprintf("Phase %d (Role: %s) failed: %v", i+1, m.Role, err)
			finalOutput.WriteString(errStr + "\n")
			return ErrorResult(errStr).WithError(err) // Fail fast
		}

		previousResult = content

		finalOutput.WriteString(
			fmt.Sprintf("### Phase %d completed by Role: [%s]\n%s\n\n", i+1, m.Role, previousResult),
		)
	}

	return &ToolResult{
		ForLLM:  finalOutput.String(),
		ForUser: buildUserSummary("Sequential", members, nil),
	}
}

// executeParallel runs every member concurrently (one goroutine each) and
// collects results over a buffered channel. Partial failures are reported but
// successful outputs are preserved; IsError is set only when ALL workers fail.
func (t *TeamTool) executeParallel(ctx context.Context, baseConfig ToolLoopConfig, members []TeamMember) *ToolResult {
	var wg sync.WaitGroup
	type workResult struct {
		index int
		role  string
		res   string
		err   error
	}

	// Buffered to len(members) so worker sends never block.
	resultsChan := make(chan workResult, len(members))

	for i, m := range members {
		wg.Add(1)
		go func(index int, member TeamMember) {
			defer wg.Done()

			logger.InfoCF("team", fmt.Sprintf("[%s] Parallel worker starting", member.Role), map[string]any{
				"member_index": index,
				"model":        member.Model,
			})

			messages := []providers.Message{
				{Role: "system", Content: member.Role},
				{Role: "user", Content: member.Task},
			}

			workerConfig, err := t.buildWorkerConfig(baseConfig, baseConfig.Tools, member)
			if err != nil {
				resultsChan <- workResult{index: index, role: member.Role, err: err}
				return
			}

			content, _, err := t.spawnWorker(ctx, workerConfig, messages, baseConfig.RemainingTokenBudget)
			if err != nil {
				resultsChan <- workResult{index: index, role: member.Role, err: err}
				return
			}
			resultsChan <- workResult{index: index, role: member.Role, res: content}
			logger.InfoCF("team", fmt.Sprintf("[%s] Parallel worker finished", member.Role), map[string]any{
				"member_index": index,
			})
		}(i, m)
	}

	// Wait for all goroutines to finish
	wg.Wait()
	close(resultsChan)

	// Pre-allocate to maintain order since channels don't guarantee arrival order
	orderedResults := make([]workResult, len(members))
	for res := range resultsChan {
		orderedResults[res.index] = res
	}

	var successOutput strings.Builder
	var failureOutput strings.Builder
	successCount, failureCount := 0, 0

	successOutput.WriteString("Team Execution Summary (Parallel):\n\n")

	for _, res := range orderedResults {
		if res.err != nil {
			failureCount++
			failureOutput.WriteString(fmt.Sprintf("### Worker [%s] FAILED:\n%v\n\n", res.role, res.err))
		} else {
			successCount++
			successOutput.WriteString(fmt.Sprintf("### Worker [%s] Output:\n%s\n\n", res.role, res.res))
		}
	}

	if failureCount == 0 {
		// All workers succeeded
		return &ToolResult{
			ForLLM:  successOutput.String(),
			ForUser: buildUserSummary("Parallel", members, nil),
		}
	}

	// Partial failure: preserve successful results and append failure summary.
	// This lets the coordinator decide how to handle the partial outcome.
	fullOutput := successOutput.String()
	// NOTE(review): this condition is always true here (failureCount != 0 was
	// handled above) — the guard is redundant but harmless.
	if failureCount > 0 {
		fullOutput += "---\n## ⚠️ Partial Failures\n\n" + failureOutput.String() +
			fmt.Sprintf(
				"\n%d/%d workers succeeded. %d worker(s) failed. The successful results above may still be usable.",
				successCount,
				len(members),
				failureCount,
			)
	}

	return &ToolResult{
		ForLLM: fullOutput,
		ForUser: fmt.Sprintf(
			"⚠️ Parallel execution: %d/%d workers succeeded. %d failed.",
			successCount,
			len(members),
			failureCount,
		),
		IsError: failureCount == len(members),
	}
}

// executeEvaluatorOptimizer runs a two-member worker/evaluator refinement loop:
// the worker (members[0]) produces output, the evaluator (members[1]) judges it,
// and rejection feedback is injected back into the worker's message history.
func (t *TeamTool) executeEvaluatorOptimizer(
	ctx context.Context,
	baseConfig ToolLoopConfig,
	members []TeamMember,
	contextLimit int,
) *ToolResult {
	if len(members) != 2 {
		return ErrorResult("The evaluator_optimizer strategy requires exactly two members: [0] Worker, [1] Evaluator.")
	}

	worker := members[0]
	evaluator := members[1]

	var finalOutput strings.Builder
	finalOutput.WriteString("Team Execution Summary (Evaluator-Optimizer):\n\n")

	// 1. Initialize the stateful memory for the worker
	workerMessages := []providers.Message{
		{Role: "system", Content: worker.Role},
		{Role: "user", Content: worker.Task},
	}

	// Snapshot the team config under the manager's read lock.
	sm := t.manager
	sm.mu.RLock()
	teamConfig := sm.teamConfig
	sm.mu.RUnlock()

	maxLoops := 5
	if teamConfig.MaxEvaluatorLoops > 0 {
		maxLoops = teamConfig.MaxEvaluatorLoops
	}

	// Pre-compute both configs once — they don't change between loop iterations.
	workerConfig, err := t.buildWorkerConfig(baseConfig, baseConfig.Tools, worker)
	if err != nil {
		return ErrorResult(fmt.Sprintf("Worker configuration failed: %v", err)).WithError(err)
	}
	// The evaluator gets an empty registry (see spawnWorkerEmptyTools below).
	evalConfig, err := t.buildWorkerConfig(baseConfig, NewToolRegistry(), evaluator)
	if err != nil {
		return ErrorResult(fmt.Sprintf("Evaluator configuration failed: %v", err)).WithError(err)
	}

	logger.InfoCF("team", "Evaluator-Optimizer starting", map[string]any{
		"worker":    worker.Role,
		"evaluator": evaluator.Role,
		"max_loops": maxLoops,
	})

	for attempt := 1; attempt <= maxLoops; attempt++ {
		finalOutput.WriteString(fmt.Sprintf("## Attempt %d\n", attempt))
		logger.InfoCF("team", fmt.Sprintf("Evaluator-Optimizer attempt %d/%d", attempt, maxLoops), map[string]any{})

		// 2. Trigger Worker (resumes from its exact previous state!)
		workerContent, workerMsgs, err := t.spawnWorker(
			ctx,
			workerConfig,
			workerMessages,
			baseConfig.RemainingTokenBudget,
		)
		if err != nil {
			errStr := fmt.Sprintf("Worker failed on attempt %d: %v", attempt, err)
			finalOutput.WriteString(errStr + "\n")
			return ErrorResult(errStr).WithError(err)
		}

		// Save the worker's cognitive state so it remembers its thought process for the next loop
		if workerMsgs != nil {
			workerMessages = workerMsgs
		}

		finalOutput.WriteString(fmt.Sprintf("### Worker Output:\n%s\n\n", workerContent))

		// 3. Trigger Evaluator (Ephemeral, stateless evaluation)
		// The evaluator only needs to reason about text — give it no tools to avoid
		// unnecessary tool calls, wasted tokens, and potential side effects.
		evalContext := fmt.Sprintf(
			"%s\n\n--- Worker's Output to Evaluate ---\n%s\n\nIf the output is completely correct and fulfills the task, you MUST reply starting with strictly '[PASS]'. Otherwise, explain the issues in detail.",
			evaluator.Task,
			truncateContextN(workerContent, contextLimit),
		)

		evalMessages := []providers.Message{
			{Role: "system", Content: evaluator.Role},
			{Role: "user", Content: evalContext},
		}

		evalContent, err := t.spawnWorkerEmptyTools(ctx, evalConfig, evalMessages)
		if err != nil {
			errStr := fmt.Sprintf("Evaluator failed on attempt %d: %v", attempt, err)
			finalOutput.WriteString(errStr + "\n")
			return ErrorResult(errStr).WithError(err)
		}

		finalOutput.WriteString(fmt.Sprintf("### Evaluator Feedback:\n%s\n\n", evalContent))

		// 4. Check for PASS condition
		if strings.HasPrefix(strings.TrimSpace(evalContent), "[PASS]") {
			finalOutput.WriteString("✅ Evaluation Passed! Loop finished successfully.\n")
			logger.InfoCF("team", "Evaluator-Optimizer passed", map[string]any{"attempt": attempt})
			return &ToolResult{
				ForLLM: finalOutput.String(),
				ForUser: fmt.Sprintf(
					"✅ Evaluator-Optimizer passed on attempt %d/%d (worker: %s).",
					attempt,
					maxLoops,
					worker.Role,
				),
			}
		}

		logger.InfoCF(
			"team",
			"Evaluator-Optimizer did not pass, retrying",
			map[string]any{"attempt": attempt, "max_loops": maxLoops},
		)

		// 5. If not passed, and not the last attempt, inject feedback into Worker's stateful memory
		if attempt < maxLoops {
			injection := fmt.Sprintf(
				"The evaluator rejected your previous attempt. Please fix the issues based on this feedback:\n\n%s",
				evalContent,
			)
			workerMessages = append(workerMessages, providers.Message{
				Role:    "user",
				Content: injection,
			})
		}
	}

	// Budget of attempts exhausted: return the accumulated transcript as-is.
	finalOutput.WriteString("❌ Maximum evaluation loops reached without a [PASS]. Returning current state.\n")
	logger.WarnCF("team", "Evaluator-Optimizer exhausted max loops", map[string]any{"max_loops": maxLoops})
	return &ToolResult{
		ForLLM:  finalOutput.String(),
		ForUser: fmt.Sprintf("❌ Evaluator-Optimizer exhausted %d attempts without a [PASS].", maxLoops),
	}
}

// executeDAG validates the members' depends_on graph (undefined nodes, cycles)
// and then executes it: nodes with zero unmet dependencies run concurrently,
// each completed node's output is forwarded to its dependents, and the first
// worker error cancels the whole team (fast fail).
func (t *TeamTool) executeDAG(
	ctx context.Context,
	cancel context.CancelFunc,
	baseConfig ToolLoopConfig,
	members []TeamMember,
	contextLimit int,
) *ToolResult {
	logger.InfoCF("team", "DAG execution starting", map[string]any{"member_count": len(members)})
	// 1. Build and VALIDATE dependency graph
	memberMap := make(map[string]TeamMember)
	inDegree := make(map[string]int)
	graph := make(map[string][]string) // node -> nodes that depend on it

	// Register all valid members first
	for _, m := range members {
		memberMap[m.ID] = m
		inDegree[m.ID] = 0
		graph[m.ID] = []string{}
	}

	// Build edges and check for ghost nodes
	for _, m := range members {
		for _, dep := range m.DependsOn {
			if _, exists := memberMap[dep]; !exists {
				return ErrorResult(
					fmt.Sprintf("DAG Validation Error: Member [%s] depends on undefined member [%s]", m.ID, dep),
				)
			}
			graph[dep] = append(graph[dep], m.ID)
			inDegree[m.ID]++
		}
	}

	// 1.5. Cycle Detection using Kahn's Algorithm
	// (runs on a copy of inDegree so the real counters stay intact for execution)
	var kahnQueue []string
	kahnInDegree := make(map[string]int)
	for k, v := range inDegree {
		kahnInDegree[k] = v
		if v == 0 {
			kahnQueue = append(kahnQueue, k)
		}
	}

	processedCount := 0
	for len(kahnQueue) > 0 {
		curr := kahnQueue[0]
		kahnQueue = kahnQueue[1:]
		processedCount++

		for _, dependent := range graph[curr] {
			kahnInDegree[dependent]--
			if kahnInDegree[dependent] == 0 {
				kahnQueue = append(kahnQueue, dependent)
			}
		}
	}

	// If Kahn's algorithm couldn't visit every node, a cycle exists.
	if processedCount != len(members) {
		return ErrorResult(
			"DAG Validation Error: Circular dependency (cycle) detected in the team layout. Please fix your 'depends_on' definitions.",
		)
	}

	// 2. Channels for coordination (buffered to len(members) so sends never block)
	type nodeResult struct {
		id  string
		res string
		err error
	}
	readyChan := make(chan string, len(members))
	resultChan := make(chan nodeResult, len(members))

	// Accumulated context text passed from dependencies to dependents
	contextMap := make(map[string]*strings.Builder)
	var contextMu sync.Mutex

	// 3. Initialize queue with nodes having 0 in-degree
	nodesToProcess := len(members)
	completedNodes := 0
	for id, deg := range inDegree {
		if deg == 0 {
			readyChan <- id
		}
	}

	var wg sync.WaitGroup
	var masterErr error
	var masterErrMu sync.Mutex

	// Shared results store for the final output
	finalResults := make(map[string]string)
	var finalResultsMu sync.Mutex

	// 4. DAG Execution Loop
	// Continue until all nodes have completed
	for completedNodes < nodesToProcess {
		select {
		case <-ctx.Done():
			return ErrorResult("DAG execution timed out or canceled")

		case memberID := <-readyChan:
			wg.Add(1)
			go func(id string) {
				defer wg.Done()

				m := memberMap[id]

				// Construct the task with context from all dependencies
				actualTask := m.Task
				contextMu.Lock()
				b := contextMap[id]
				depsContext := ""
				if b != nil {
					depsContext = b.String()
				}
				contextMu.Unlock()

				if depsContext != "" {
					actualTask = fmt.Sprintf(
						"%s\n\n--- Context from dependencies ---\n%s",
						m.Task,
						truncateContextN(depsContext, contextLimit),
					)
				}

				messages := []providers.Message{
					{Role: "system", Content: m.Role},
					{Role: "user", Content: actualTask},
				}

				workerConfig, err := t.buildWorkerConfig(baseConfig, baseConfig.Tools, m)
				if err != nil {
					masterErrMu.Lock()
					if masterErr == nil {
						masterErr = err
					}
					masterErrMu.Unlock()
					resultChan <- nodeResult{id: id, err: err}
					return
				}

				content, _, err := t.spawnWorker(ctx, workerConfig, messages, baseConfig.RemainingTokenBudget)
				if err != nil {
					masterErrMu.Lock()
					if masterErr == nil {
						masterErr = fmt.Errorf("worker [%s] failed: %v", m.ID, err)
					}
					masterErrMu.Unlock()
					resultChan <- nodeResult{id: id, err: err}
					return
				}

				// Store result for final output
				finalResultsMu.Lock()
				finalResults[id] = content
				finalResultsMu.Unlock()

				// Pass result to dependents
				resultChan <- nodeResult{id: id, res: content}
			}(memberID)

		case res := <-resultChan:
			if res.err != nil {
				// Fast fail on first error.
				// Cancel the team context first so that all in-flight goroutines
				// receive the cancellation signal and terminate cleanly.
				cancel()
				wg.Wait()
				return ErrorResult(res.err.Error())
			}

			// Mark this node as completed
			completedNodes++

			// Update dependents
			for _, dependentID := range graph[res.id] {
				contextMu.Lock()
				b := contextMap[dependentID]
				if b == nil {
					b = &strings.Builder{}
				}
				b.WriteString(fmt.Sprintf("--- Result from [%s] ---\n%s\n\n", res.id, res.res))
				contextMap[dependentID] = b
				contextMu.Unlock()

				// inDegree is only touched from this (single) select loop, so
				// no lock is needed here.
				inDegree[dependentID]--
				if inDegree[dependentID] == 0 {
					readyChan <- dependentID
				}
			}
		}
	}

	// Wait for any remaining goroutines (though the select loop handles the exact count)
	wg.Wait()

	if masterErr != nil {
		return ErrorResult(masterErr.Error())
	}

	// 5. Format final output
	var finalOutput strings.Builder
	finalOutput.WriteString("Team Execution Summary (DAG):\n\n")

	// Preserve original member order for final output readability
	for _, m := range members {
		if res, ok := finalResults[m.ID]; ok {
			finalOutput.WriteString(fmt.Sprintf("### Worker [%s] (Role: %s) Output:\n%s\n\n", m.ID, m.Role, res))
		}
	}

	return &ToolResult{
		ForLLM:  finalOutput.String(),
		ForUser: buildUserSummary("DAG", members, nil),
	}
}

// truncateContextN limits the number of runes in ctx to maxRunes.
// It prevents Context Window Explosion (Token Bombs) when passing upstream
// worker results into downstream agent prompts.
// NOTE(review): the parameter named "ctx" is a string here, not a
// context.Context — consider renaming to avoid shadowing the Go convention.
func truncateContextN(ctx string, maxRunes int) string {
	runes := []rune(ctx)
	if len(runes) > maxRunes {
		return string(runes[:maxRunes]) + "\n...[Context truncated due to length]..."
	}
	return ctx
}

// buildUserSummary produces a concise human-readable summary for the ForUser field,
// listing each member's role. Errors (if any) are appended as a separate section.
func buildUserSummary(strategy string, members []TeamMember, errors []string) string {
	var sb strings.Builder
	sb.WriteString(fmt.Sprintf("Team (%s) completed — %d member(s):\n", strategy, len(members)))
	for i, m := range members {
		sb.WriteString(fmt.Sprintf(" [%d] %s", i+1, m.Role))
		if m.Model != "" {
			sb.WriteString(fmt.Sprintf(" (model: %s)", m.Model))
		}
		sb.WriteString("\n")
	}
	if len(errors) > 0 {
		sb.WriteString("\n⚠️ Failures:\n")
		for _, e := range errors {
			sb.WriteString(" • " + e + "\n")
		}
	}
	// Trim the trailing newline so the summary composes cleanly.
	return strings.TrimRight(sb.String(), "\n")
}
diff --git a/pkg/tools/team_test.go b/pkg/tools/team_test.go
new file mode 100644
index 000000000..89361ac71
--- /dev/null
+++ b/pkg/tools/team_test.go
package tools

import (
	"context"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/sipeed/picoclaw/pkg/config"
	"github.com/sipeed/picoclaw/pkg/providers"
)

// TestUpgradeRegistryForConcurrency verifies that cloning a registry upgrades
// only ConcurrencyUpgradeable tools and leaves the original registry untouched.
func TestUpgradeRegistryForConcurrency(t *testing.T) {
	// Create a standard tool registry
	original := NewToolRegistry()

	// Register a mix of tools: some upgradeable, some not
	readTool := NewReadFileTool("", false, MaxReadFileSize)
	listTool := NewListDirTool("", false) // Not upgradeable
	writeTool := NewWriteFileTool("", false)

	original.Register(readTool)
	original.Register(listTool)
	original.Register(writeTool)

	// Perform the upgrade
	upgraded := upgradeRegistryForConcurrency(original)

	// Verify count matches
	assert.Equal(
		t,
		len(original.ListTools()),
		len(upgraded.ListTools()),
		"Upgraded registry should have same number of tools",
	)

	// Verify ReadFileTool got upgraded
	actualReadTool, ok := upgraded.Get("read_file")
	assert.True(t, ok)
	if upgradedRead, isUpgraded := actualReadTool.(*ReadFileTool); isUpgraded {
		_, isConcurrent := upgradedRead.fs.(*ConcurrentFS)
		assert.True(t, isConcurrent, "read_file should have been upgraded to ConcurrentFS")
	}

	// Verify WriteFileTool got upgraded
	actualWriteTool, ok := upgraded.Get("write_file")
	assert.True(t, ok)
	if upgradedWrite, isUpgraded := actualWriteTool.(*WriteFileTool); isUpgraded {
		_, isConcurrent := upgradedWrite.fs.(*ConcurrentFS)
		assert.True(t, isConcurrent, "write_file should have been upgraded to ConcurrentFS")
	}

	// Verify ListDirTool remained the same
	actualListTool, ok := upgraded.Get("list_dir")
	assert.True(t, ok)
	_, isListDir := actualListTool.(*ListDirTool)
	assert.True(t, isListDir, "list_dir should still be ListDirTool")

	// Double check list_dir doesn't randomly have ConcurrentFS injected
	if listImpl, ok := actualListTool.(*ListDirTool); ok {
		_, isConcurrent := listImpl.fs.(*ConcurrentFS)
		assert.False(t, isConcurrent, "list_dir should NOT have ConcurrentFS because it's not upgradeable")
	}

	// Double check original registry was entirely unmodified
	origReadTool, _ := original.Get("read_file")
	if origRead, _ := origReadTool.(*ReadFileTool); origRead != nil {
		_, isConcurrent := origRead.fs.(*ConcurrentFS)
		assert.False(t, isConcurrent, "Original registry components MUST REMAIN completely lock-free")
	}
}

// TestBuildWorkerConfig covers alias resolution, base-model inheritance, and
// rejection of models outside the allowed list.
func TestBuildWorkerConfig(t *testing.T) {
	// 1. Setup global config with model aliases
	cfg := &config.Config{
		ModelList: []*config.ModelConfig{
			func() *config.ModelConfig {
				m := &config.ModelConfig{
					ModelName: "strong-model",
					Model:     "openai/gpt-4o",
					APIBase:   "https://api.openai.com/v1",
				}
				m.SetAPIKey("sk-test")
				return m
			}(),
			func() *config.ModelConfig {
				m := &config.ModelConfig{
					ModelName: "fast-model",
					Model:     "anthropic/claude-3-haiku",
					APIBase:   "https://api.anthropic.com/v1",
				}
				m.SetAPIKey("sk-test")
				return m
			}(),
			func() *config.ModelConfig {
				m := &config.ModelConfig{
					ModelName: "direct-id",
					Model:     "openai/gpt-3.5-turbo",
					APIBase:   "https://api.openai.com/v1",
				}
				m.SetAPIKey("sk-test")
				return m
			}(),
		},
	}

	// 2. Setup SubagentManager (needed by TeamTool for IsModelAllowed check)
	manager := NewSubagentManager(nil, "default-model", nil, "", config.TeamToolsConfig{
		AllowedModels: []config.TeamModelConfig{
			{Name: "fast-model", Tags: []string{"vision"}},
			{Name: "strong-model", Tags: []string{"coding"}},
			{Name: "direct-id", Tags: []string{"coding"}},
		},
	}, nil)

	// 3. Create TeamTool
	tool := NewTeamTool(manager, cfg)

	baseConfig := ToolLoopConfig{
		Model: "base-model",
	}

	tests := []struct {
		name          string
		memberModel   string
		expectedModel string
		expectError   bool
	}{
		{
			name:          "Resolve alias to actual ID",
			memberModel:   "strong-model",
			expectedModel: "gpt-4o",
			expectError:   false,
		},
		{
			name:          "Resolve another alias",
			memberModel:   "fast-model",
			expectedModel: "claude-3-haiku",
			expectError:   false,
		},
		{
			name:          "Resolve direct name if it matches an alias",
			memberModel:   "direct-id",
			expectedModel: "gpt-3.5-turbo",
			expectError:   false,
		},
		{
			name:          "Inherit base model if member model is empty",
			memberModel:   "",
			expectedModel: "base-model",
			expectError:   false,
		},
		{
			name:          "Error if model is not allowed",
			memberModel:   "forbidden-model",
			expectedModel: "",
			expectError:   true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := TeamMember{
				Model: tt.memberModel,
			}
			res, err := tool.buildWorkerConfig(baseConfig, nil, m)

			if tt.expectError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.expectedModel, res.Model)
				if tt.memberModel != "" {
					assert.NotNil(t, res.Provider, "Provider should be set when model is specified")
				} else {
					assert.Nil(t, res.Provider, "Provider should be nil (inherited from baseConfig)")
				}
			}
		})
	}
}

// mockProvider returns canned responses in call order.
// NOTE(review): callCount is mutated without a lock — only safe for
// sequential-strategy tests.
type mockProvider struct {
	responses []string
	callCount int
}

func (m *mockProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	options map[string]any,
) (*providers.LLMResponse, error) {
	if m.callCount >= len(m.responses) {
		return &providers.LLMResponse{Content: "Default response"}, nil
	}
	resp := m.responses[m.callCount]
	m.callCount++
	return &providers.LLMResponse{Content: resp}, nil
}

func (m *mockProvider) GetDefaultModel() string {
	return "mock-model"
}

// mockProviderWithID returns responses based on the role in the system message
// This is needed for DAG tests where execution order is non-deterministic
type mockProviderWithID struct {
	responses map[string]string
	mu        sync.Mutex
	callCount int
}

func (m *mockProviderWithID) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	options map[string]any,
) (*providers.LLMResponse, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.callCount++

	// Extract role from system message to determine which response to return
	for _, msg := range messages {
		if msg.Role == "system" {
			if resp, ok := m.responses[msg.Content]; ok {
				return &providers.LLMResponse{Content: resp}, nil
			}
		}
	}
	return &providers.LLMResponse{Content: "Default response"}, nil
}

func (m *mockProviderWithID) GetDefaultModel() string {
	return "mock-model"
}

// TestExecuteSequential checks that both phases run in order and their
// outputs both reach the final summary.
func TestExecuteSequential(t *testing.T) {
	// 1. Setup mock provider to return specific outputs for each agent
	mock := &mockProvider{
		responses: []string{
			"Result from Agent A",
			"Derived result from Agent B",
		},
	}

	// 2. Setup TeamTool
	manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
	tool := NewTeamTool(manager, &config.Config{})

	baseConfig := ToolLoopConfig{
		Provider:      mock,
		Model:         "mock-model",
		MaxIterations: 1,
	}

	members := []TeamMember{
		{ID: "worker-A", Role: "Researcher", Task: "Research topic X"},
		{ID: "worker-B", Role: "Writer", Task: "Write summary of researcher output"},
	}

	// 3. Run sequential execution
	result := tool.executeSequential(context.Background(), baseConfig, members, 1000)

	// 4. Verify results
	assert.False(t, result.IsError, "Should not return error")
	assert.Contains(t, result.ForLLM, "Result from Agent A")
	assert.Contains(t, result.ForLLM, "Derived result from Agent B")
	assert.Equal(t, 2, mock.callCount, "Should have called mock provider exactly twice")
}

// TestExecuteDAG covers fan-in, linear chains, cycle detection, undefined
// dependencies, and a multi-root diamond-shaped graph.
func TestExecuteDAG(t *testing.T) {
	t.Run("Simple DAG with dependencies", func(t *testing.T) {
		// Setup: A -> C, B -> C (C depends on both A and B)
		mock := &mockProviderWithID{
			responses: map[string]string{
				"Data Collector A": "Data from A",
				"Data Collector B": "Data from B",
				"Aggregator":       "Combined result from C using A and B",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(), // Empty registry for test
		}

		members := []TeamMember{
			{ID: "A", Role: "Data Collector A", Task: "Collect data A"},
			{ID: "B", Role: "Data Collector B", Task: "Collect data B"},
			{ID: "C", Role: "Aggregator", Task: "Combine data", DependsOn: []string{"A", "B"}},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		result := tool.executeDAG(ctx, cancel, baseConfig, members, 1000)

		assert.False(t, result.IsError, "Should not return error")
		assert.Contains(t, result.ForLLM, "Data from A")
		assert.Contains(t, result.ForLLM, "Data from B")
		assert.Contains(t, result.ForLLM, "Combined result from C")
		assert.Equal(t, 3, mock.callCount, "Should have called mock provider 3 times")
	})

	t.Run("Linear chain DAG", func(t *testing.T) {
		// Setup: A -> B -> C (linear dependency chain)
		mock := &mockProviderWithID{
			responses: map[string]string{
				"Step 1": "Step 1 output",
				"Step 2": "Step 2 output",
				"Step 3": "Step 3 output",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "A", Role: "Step 1", Task: "Do step 1"},
			{ID: "B", Role: "Step 2", Task: "Do step 2", DependsOn: []string{"A"}},
			{ID: "C", Role: "Step 3", Task: "Do step 3", DependsOn: []string{"B"}},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		result := tool.executeDAG(ctx, cancel, baseConfig, members, 1000)

		assert.False(t, result.IsError)
		assert.Contains(t, result.ForLLM, "Step 1 output")
		assert.Contains(t, result.ForLLM, "Step 2 output")
		assert.Contains(t, result.ForLLM, "Step 3 output")
	})

	t.Run("Detect circular dependency", func(t *testing.T) {
		// Setup: A -> B -> C -> A (circular)
		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      &mockProvider{},
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "A", Role: "Worker A", Task: "Task A", DependsOn: []string{"C"}},
			{ID: "B", Role: "Worker B", Task: "Task B", DependsOn: []string{"A"}},
			{ID: "C", Role: "Worker C", Task: "Task C", DependsOn: []string{"B"}},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		result := tool.executeDAG(ctx, cancel, baseConfig, members, 1000)

		assert.True(t, result.IsError, "Should detect circular dependency")
		assert.Contains(t, result.ForLLM, "cycle")
	})

	t.Run("Detect undefined dependency", func(t *testing.T) {
		// Setup: A depends on non-existent "X"
		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      &mockProvider{},
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "A", Role: "Worker A", Task: "Task A", DependsOn: []string{"X"}},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		result := tool.executeDAG(ctx, cancel, baseConfig, members, 1000)

		assert.True(t, result.IsError, "Should detect undefined dependency")
		assert.Contains(t, result.ForLLM, "undefined member")
	})

	t.Run("Complex DAG with multiple roots", func(t *testing.T) {
		// Setup: A, B (roots) -> C, D -> E
		mock := &mockProviderWithID{
			responses: map[string]string{
				"Root A":   "Root A output",
				"Root B":   "Root B output",
				"Worker C": "C output",
				"Worker D": "D output",
				"Final":    "Final E output",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "A", Role: "Root A", Task: "Root task A"},
			{ID: "B", Role: "Root B", Task: "Root task B"},
			{ID: "C", Role: "Worker C", Task: "Task C", DependsOn: []string{"A"}},
			{ID: "D", Role: "Worker D", Task: "Task D", DependsOn: []string{"B"}},
			{ID: "E", Role: "Final", Task: "Final task", DependsOn: []string{"C", "D"}},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		result := tool.executeDAG(ctx, cancel, baseConfig, members, 1000)

		assert.False(t, result.IsError)
		assert.Contains(t, result.ForLLM, "Root A output")
		assert.Contains(t, result.ForLLM, "Root B output")
		assert.Contains(t, result.ForLLM, "Final E output")
		assert.Equal(t, 5, mock.callCount)
	})
}

func TestExecuteEvaluatorOptimizer(t *testing.T) {
	t.Run("Pass on first attempt", func(t *testing.T)
// TestExecuteEvaluatorOptimizer exercises the two-member worker/evaluator
// strategy: the worker produces an attempt, the evaluator either emits a
// "[PASS]" verdict or feedback that is fed back to the worker, and the loop
// is bounded by config.TeamToolsConfig.MaxEvaluatorLoops.
//
// The mockProvider returns its responses slice in order, so each subtest's
// slice encodes the expected call interleaving: worker, evaluator, worker, ...
func TestExecuteEvaluatorOptimizer(t *testing.T) {
	t.Run("Pass on first attempt", func(t *testing.T) {
		// Sequence: worker output, then an immediate evaluator "[PASS]" verdict.
		mock := &mockProvider{
			responses: []string{
				"Perfect code implementation",
				"[PASS] The code is correct",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{
			MaxEvaluatorLoops: 3,
		}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write a function"},
			{ID: "evaluator", Role: "Code Reviewer", Task: "Review the code"},
		}

		result := tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)

		assert.False(t, result.IsError)
		assert.Contains(t, result.ForLLM, "Perfect code implementation")
		assert.Contains(t, result.ForLLM, "[PASS]")
		assert.Contains(t, result.ForUser, "passed on attempt 1")
		assert.Equal(t, 2, mock.callCount, "Should call worker once and evaluator once")
	})

	t.Run("Pass on second attempt after feedback", func(t *testing.T) {
		// Sequence: worker attempt 1, evaluator feedback (no "[PASS]"),
		// worker attempt 2, evaluator "[PASS]" verdict.
		mock := &mockProvider{
			responses: []string{
				"Initial code with bug",
				"Missing error handling",
				"Fixed code with error handling",
				"[PASS] Now it's correct",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{
			MaxEvaluatorLoops: 3,
		}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write a function"},
			{ID: "evaluator", Role: "Code Reviewer", Task: "Review the code"},
		}

		result := tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)

		assert.False(t, result.IsError)
		assert.Contains(t, result.ForLLM, "Initial code with bug")
		assert.Contains(t, result.ForLLM, "Missing error handling")
		assert.Contains(t, result.ForLLM, "Fixed code with error handling")
		assert.Contains(t, result.ForLLM, "[PASS]")
		assert.Contains(t, result.ForUser, "passed on attempt 2")
		assert.Equal(t, 4, mock.callCount, "Should call worker twice and evaluator twice")
	})

	t.Run("Exhaust max loops without pass", func(t *testing.T) {
		// Three worker/evaluator rounds, none producing "[PASS]": the loop
		// must stop at MaxEvaluatorLoops and report exhaustion, not error.
		mock := &mockProvider{
			responses: []string{
				"Attempt 1",
				"Still has issues",
				"Attempt 2",
				"Still not good",
				"Attempt 3",
				"Still failing",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{
			MaxEvaluatorLoops: 3,
		}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mock,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write a function"},
			{ID: "evaluator", Role: "Code Reviewer", Task: "Review the code"},
		}

		result := tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)

		assert.False(t, result.IsError, "Should not error, just report exhaustion")
		assert.Contains(t, result.ForLLM, "Maximum evaluation loops reached")
		assert.Contains(t, result.ForUser, "exhausted 3 attempts")
		assert.Equal(t, 6, mock.callCount, "Should call worker 3 times and evaluator 3 times")
	})

	t.Run("Require exactly two members", func(t *testing.T) {
		// Member-count validation must reject anything other than exactly
		// one worker plus one evaluator.
		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      &mockProvider{},
			Model:         "mock-model",
			MaxIterations: 1,
		}

		// Test with 1 member
		members := []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write a function"},
		}

		result := tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)
		assert.True(t, result.IsError)
		assert.Contains(t, result.ForLLM, "exactly two members")

		// Test with 3 members
		members = []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write a function"},
			{ID: "evaluator", Role: "Reviewer", Task: "Review"},
			{ID: "extra", Role: "Extra", Task: "Extra task"},
		}

		result = tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)
		assert.True(t, result.IsError)
		assert.Contains(t, result.ForLLM, "exactly two members")
	})

	t.Run("Stateful worker memory across iterations", func(t *testing.T) {
		// This test verifies that the worker's message history is preserved
		// across iterations, allowing it to "remember" previous feedback
		mockWithMemory := &mockProvider{
			responses: []string{
				"First attempt",
				"Needs improvement: add validation",
				"Second attempt with validation",
				"[PASS] Good now",
			},
		}

		manager := NewSubagentManager(nil, "mock-model", nil, "", config.TeamToolsConfig{
			MaxEvaluatorLoops: 3,
		}, nil)
		tool := NewTeamTool(manager, &config.Config{})

		baseConfig := ToolLoopConfig{
			Provider:      mockWithMemory,
			Model:         "mock-model",
			MaxIterations: 1,
			Tools:         NewToolRegistry(),
		}

		members := []TeamMember{
			{ID: "worker", Role: "Coder", Task: "Write validation logic"},
			{ID: "evaluator", Role: "Reviewer", Task: "Check validation"},
		}

		result := tool.executeEvaluatorOptimizer(context.Background(), baseConfig, members, 1000)

		assert.False(t, result.IsError)
		assert.Contains(t, result.ForLLM, "First attempt")
		assert.Contains(t, result.ForLLM, "Needs improvement")
		assert.Contains(t, result.ForLLM, "Second attempt with validation")
		assert.Contains(t, result.ForLLM, "[PASS]")

		// Verify the worker was called twice (once initially, once after feedback)
		// and evaluator was called twice
		assert.Equal(t, 4, mockWithMemory.callCount)
	})
}
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
	Provider      providers.LLMProvider // LLM backend; its Chat method is called once per loop iteration
	Model         string                // model identifier forwarded to Provider.Chat
	Tools         *ToolRegistry         // available tools; when nil, tool calls resolve to a "No tools available" error result
	MaxIterations int                   // iteration limit for the LLM + tool loop
	LLMOptions    map[string]any        // extra provider options forwarded to each Chat call
	// RemainingTokenBudget, when non-nil, is decremented by each response's
	// Usage.TotalTokens. On reaching zero or below, the loop injects a
	// wrap-up instruction and makes one final Chat call instead of
	// hard-failing (soft enforcement). The pointer lets several loops draw
	// from one shared budget; nil disables budget tracking entirely.
	RemainingTokenBudget *atomic.Int64
}

// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
	Content    string              // final assistant text (may be a wrap-up summary when the budget was exhausted)
	Iterations int                 // number of loop iterations actually executed
	Messages   []providers.Message // Allows caller to retain stateful context across executions
}
+ logger.WarnCF("toolloop", "Token budget exhausted, injecting wrap-up signal", + map[string]any{ + "deficit": -newBudget, + "iteration": iteration, + }) + finalContent = response.Content + messages = append(messages, providers.Message{ + Role: "assistant", + Content: response.Content, + ReasoningContent: response.ReasoningContent, // [Fix] Preserve reasoning content to maintain context + }) + messages = append(messages, providers.Message{ + Role: "user", + Content: "[SYSTEM] Token budget has been exhausted. Stop all tool calls immediately and return the best result you have completed so far. Do not call any more tools.", + }) + // One final LLM call to get a summary/wrap-up from the model + if finalResp, err := config.Provider.Chat( + ctx, messages, nil, config.Model, config.LLMOptions, + ); err == nil { + finalContent = finalResp.Content + } + break + } else if originalBudget > 0 && newBudget < originalBudget/2 { + // Budget below 50%: soft warning injected into next iteration's context. + logger.WarnCF("toolloop", "Token budget below 50%, injecting advisory", + map[string]any{"remaining": newBudget, "iteration": iteration}) + messages = append(messages, providers.Message{ + Role: "user", + Content: "[SYSTEM] Advisory: token budget is running low. Please prioritize completing the most critical parts of your task and avoid unnecessary tool calls.", + }) + } + } + + // 3.6 Truncation Recovery: LLM response was cut off (max_tokens hit or malformed JSON). + // Inject a recovery message so the LLM knows to retry with a shorter, complete response. 
+ if response.FinishReason == "truncated" { + logger.WarnCF( + "toolloop", + "LLM response was truncated (max_tokens hit), injecting recovery message", + map[string]any{"iteration": iteration}, + ) + messages = append(messages, providers.Message{ + Role: "assistant", + Content: response.Content, + ReasoningContent: response.ReasoningContent, // [Fix] Preserve reasoning content to prevent broken chain of thought + }) + messages = append(messages, providers.Message{ + Role: "user", + Content: "[SYSTEM] Your previous response was cut off because it exceeded the token limit. Please retry by producing a shorter, complete response. If you were about to call a tool, make sure the full JSON arguments are included without truncation.", + }) + continue + } + // 4. If no tool calls, we're done if len(response.ToolCalls) == 0 { finalContent = response.Content + // [Fix] Fallback for models (like Gemini 2.0 Pro Thinking) that put output in reasoning block + if finalContent == "" && response.ReasoningContent != "" { + finalContent = response.ReasoningContent + } + logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)", map[string]any{ "iteration": iteration, "content_chars": len(finalContent), }) + + // [Fix] Append the final answer to the messages array! + // Essential for Team's evaluator_optimizer strategy to retain state in the next loop. + messages = append(messages, providers.Message{ + Role: "assistant", + Content: finalContent, + ReasoningContent: response.ReasoningContent, + }) + break } @@ -104,20 +188,32 @@ func RunToolLoop( // 6. 
Build assistant message with tool calls assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, + Role: "assistant", + Content: response.Content, + ReasoningContent: response.ReasoningContent, // [Fix] Include ReasoningContent } for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) + + // [Fix] Preserve ThoughtSignature and ExtraContent for compatibility with models like Gemini 2.0/3.0 + extraContent := tc.ExtraContent + thoughtSignature := "" + if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ ID: tc.ID, Type: "function", Name: tc.Name, Arguments: tc.Arguments, Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), + Name: tc.Name, + Arguments: string(argumentsJSON), + ThoughtSignature: thoughtSignature, // [Fix] Preserve thought signature }, + ExtraContent: extraContent, // [Fix] Preserve extra content + ThoughtSignature: thoughtSignature, // [Fix] Preserve thought signature }) } messages = append(messages, assistantMsg) @@ -148,7 +244,14 @@ func RunToolLoop( var toolResult *ToolResult if config.Tools != nil { - toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + toolResult = config.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + channel, + chatID, + nil, + ) } else { toolResult = ErrorResult("No tools available") } @@ -175,5 +278,6 @@ func RunToolLoop( return &ToolLoopResult{ Content: finalContent, Iterations: iteration, + Messages: messages, }, nil }