diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go index f754abc652..f89c23795b 100644 --- a/cmd/picoclaw/internal/agent/helpers.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -12,6 +12,7 @@ import ( "github.com/chzyer/readline" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/logger" @@ -51,6 +52,24 @@ func agentCmd(message, sessionKey, model string, debug bool) error { defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + pluginsToEnable, pluginSummary, err := pluginruntime.ResolveConfiguredPlugins(cfg) + if err != nil { + return fmt.Errorf("error resolving configured plugins: %w", err) + } + if len(pluginsToEnable) > 0 { + if err := agentLoop.EnablePlugins(pluginsToEnable...); err != nil { + return fmt.Errorf("error enabling plugins: %w", err) + } + } + logger.InfoCF("agent", "Plugin selection resolved", + map[string]any{ + "plugins_enabled": pluginSummary.Enabled, + "plugins_disabled": pluginSummary.Disabled, + "plugins_unknown_enabled": pluginSummary.UnknownEnabled, + "plugins_unknown_disabled": pluginSummary.UnknownDisabled, + "plugins_warnings": pluginSummary.Warnings, + }) + // Print agent startup info (only for interactive mode) startupInfo := agentLoop.GetStartupInfo() logger.InfoCF("agent", "Agent initialized", diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 747f7d44e9..e3467cb267 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -10,6 +10,7 @@ import ( "time" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -61,6 +62,23 @@ func gatewayCmd(debug bool) error { msgBus := bus.NewMessageBus() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + pluginsToEnable, pluginSummary, err := pluginruntime.ResolveConfiguredPlugins(cfg) + if err != nil { + return fmt.Errorf("error resolving configured plugins: %w", err) + } + if len(pluginsToEnable) > 0 { + if enableErr := agentLoop.EnablePlugins(pluginsToEnable...); enableErr != nil { + return fmt.Errorf("error enabling plugins: %w", enableErr) + } + } + logger.InfoCF("agent", "Plugin selection resolved", + map[string]any{ + "plugins_enabled": pluginSummary.Enabled, + "plugins_disabled": pluginSummary.Disabled, + "plugins_unknown_enabled": pluginSummary.UnknownEnabled, + "plugins_unknown_disabled": pluginSummary.UnknownDisabled, + "plugins_warnings": pluginSummary.Warnings, + }) // Print agent startup info fmt.Println("\nšŸ“¦ Agent Status:") diff --git a/cmd/picoclaw/internal/pluginruntime/bootstrap.go b/cmd/picoclaw/internal/pluginruntime/bootstrap.go new file mode 100644 index 0000000000..4cfcffde2c --- /dev/null +++ b/cmd/picoclaw/internal/pluginruntime/bootstrap.go @@ -0,0 +1,64 @@ +package pluginruntime + +import ( + "fmt" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/builtin" +) + +type Summary struct { + Enabled []string + Disabled []string + UnknownEnabled []string + UnknownDisabled []string + Warnings []string +} + +func ResolveConfiguredPlugins(cfg *config.Config) ([]plugin.Plugin, Summary, error) { + if cfg == nil { + return nil, Summary{}, fmt.Errorf("config is nil") + } + + resolved, err := plugin.ResolveSelection( + builtin.Names(), + plugin.SelectionInput{ + DefaultEnabled: cfg.Plugins.DefaultEnabled, + Enabled: cfg.Plugins.Enabled, + Disabled: cfg.Plugins.Disabled, + }, + ) + + summary := Summary{ + Enabled: resolved.EnabledNames, + Disabled: resolved.DisabledNames, + UnknownEnabled: resolved.UnknownEnabled, + UnknownDisabled: resolved.UnknownDisabled, + Warnings: resolved.Warnings, + } + if err != nil { + return nil, summary, err + } + + catalog := builtin.Catalog() + normalizedCatalog := make(map[string]builtin.Factory, len(catalog)) + for name, factory := range catalog { + normalizedCatalog[plugin.NormalizePluginName(name)] = factory + } + + instances := make([]plugin.Plugin, 0, len(resolved.EnabledNames)) + for _, name := range resolved.EnabledNames { + factory, ok := normalizedCatalog[name] + if !ok { + return nil, summary, fmt.Errorf("builtin plugin %q has no factory", name) + } + instance := factory() + if instance == nil { + return nil, summary, fmt.Errorf("builtin plugin %q factory returned nil", name) + } + instances = append(instances, instance) + } + + return instances, summary, nil +} diff --git a/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go b/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go new file mode 100644 index 0000000000..93b2be2b37 --- /dev/null +++ b/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go @@ -0,0 +1,106 @@ +package pluginruntime + +import ( + "slices" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/builtin" +) + +func TestResolveConfiguredPlugins_UnknownEnabledReturnsError(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: []string{"missing-plugin"}, + Disabled: []string{}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err == nil { + t.Fatal("expected error for unknown enabled plugin") + } + if !strings.Contains(err.Error(), "missing-plugin") { + t.Fatalf("expected error to mention missing plugin, got %v", err) + } + if len(instances) != 0 { + t.Fatalf("expected no instances on error, got %d", len(instances)) + } + if !slices.Equal(summary.UnknownEnabled, []string{"missing-plugin"}) { + t.Fatalf("UnknownEnabled mismatch: got %v", summary.UnknownEnabled) + } +} + +func TestResolveConfiguredPlugins_ReturnsDeterministicInstances(t *testing.T) { + available := builtin.Names() + if len(available) == 0 { + t.Fatal("expected at least one builtin plugin") + } + + enabled := slices.Clone(available) + slices.Reverse(enabled) + + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: enabled, + Disabled: []string{}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err != nil { + t.Fatalf("ResolveConfiguredPlugins() error = %v", err) + } + + gotNames := pluginNames(instances) + if !slices.Equal(gotNames, available) { + t.Fatalf("plugin names mismatch: got %v, want %v", gotNames, available) + } + if !slices.Equal(summary.Enabled, available) { + t.Fatalf("summary enabled mismatch: got %v, want %v", summary.Enabled, available) + } +} + +func TestResolveConfiguredPlugins_UnknownDisabledWarns(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: true, + Enabled: []string{}, + Disabled: []string{"missing-plugin"}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err != nil { + t.Fatalf("ResolveConfiguredPlugins() error = %v", err) + } + + expectedEnabled := builtin.Names() + if !slices.Equal(pluginNames(instances), expectedEnabled) { + t.Fatalf("plugin names mismatch: got %v, want %v", pluginNames(instances), expectedEnabled) + } + if !slices.Equal(summary.UnknownDisabled, []string{"missing-plugin"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", summary.UnknownDisabled) + } + if !hasWarningSubstring(summary.Warnings, `unknown disabled plugin "missing-plugin" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", summary.Warnings) + } +} + +func pluginNames(instances []plugin.Plugin) []string { + names := make([]string, 0, len(instances)) + for _, instance := range instances { + names = append(names, instance.Name()) + } + return names +} + +func hasWarningSubstring(warnings []string, sub string) bool { + for _, warning := range warnings { + if strings.Contains(warning, sub) { + return true + } + } + return false +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8fd7328d10..a59604571d 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -22,8 +22,10 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/hooks" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/plugin" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -42,6 +44,8 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore + hooks *hooks.HookRegistry + pluginManager *plugin.Manager } // processOptions configures how a message is processed @@ -56,8 +60,6 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." - func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { registry := NewAgentRegistry(cfg, provider) @@ -172,61 +174,33 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - // Process message - func() { - // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. - // Currently disabled because files are deleted before the LLM can access their content. - // defer func() { - // if al.mediaStore != nil && msg.MediaScope != "" { - // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { - // logger.WarnCF("agent", "Failed to release media", map[string]any{ - // "scope": msg.MediaScope, - // "error": releaseErr.Error(), - // }) - // } - // } - // }() - - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) + } - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() - } + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() } } + } - if !alreadySent { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) - logger.InfoCF("agent", "Published outbound response", - map[string]any{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "content_len": len(response), - }) - } else { - logger.DebugCF( - "agent", - "Skipped outbound (message tool already sent)", - map[string]any{"channel": msg.Channel}, - ) - } + if !alreadySent { + al.sendOutbound(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) } - }() + } } } @@ -284,6 +258,111 @@ func inferMediaType(filename, contentType string) string { return "file" } +// SetHooks installs a hook registry. Must be called before Run starts. +func (al *AgentLoop) SetHooks(h *hooks.HookRegistry) error { + if al.running.Load() { + return fmt.Errorf("SetHooks must be called before Run starts") + } + al.hooks = h + + // Rewire MessageTool callbacks to route through sendOutbound for hook interception. + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + if tool, ok := agent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + if h == nil { + mt.SetSendCallback(func(channel, chatID, content string) error { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + }) + continue + } + mt.SetSendCallback(func(channel, chatID, content string) error { + if sent, reason := al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }); !sent { + if strings.TrimSpace(reason) == "" { + reason = "unspecified" + } + return fmt.Errorf("message canceled by hook: %s", reason) + } + return nil + }) + } + } + } + } + return nil +} + +// SetPluginManager installs a plugin manager and routes its hook registry into the loop. +// Must be called before Run starts. +func (al *AgentLoop) SetPluginManager(pm *plugin.Manager) error { + if pm == nil { + if err := al.SetHooks(nil); err != nil { + return err + } + al.pluginManager = nil + return nil + } + if err := al.SetHooks(pm.HookRegistry()); err != nil { + return err + } + al.pluginManager = pm + return nil +} + +// EnablePlugins is a convenience helper to build and install a plugin manager. +func (al *AgentLoop) EnablePlugins(plugins ...plugin.Plugin) error { + pm := plugin.NewManager() + if err := pm.RegisterAll(plugins...); err != nil { + return err + } + return al.SetPluginManager(pm) +} + +// sendOutbound wraps bus.PublishOutbound with the message_sending hook. +// Returns whether the message was sent and, if canceled, the cancel reason. +func (al *AgentLoop) sendOutbound(ctx context.Context, msg bus.OutboundMessage) (bool, string) { + if ctx == nil { + ctx = context.Background() + } + if al.hooks != nil { + event := &hooks.MessageSendingEvent{Channel: msg.Channel, ChatID: msg.ChatID, Content: msg.Content} + al.hooks.TriggerMessageSending(ctx, event) + if event.Cancel { + reason := event.CancelReason + if reason == "" { + reason = "unspecified" + } + logger.WarnCF("hooks", "Outbound message canceled by hook", + map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "reason": reason, + }) + return false, reason + } + msg.Content = event.Content + } + if err := al.bus.PublishOutbound(ctx, msg); err != nil { + logger.WarnCF("agent", "Failed to publish outbound message", map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "error": err.Error(), + }) + return false, err.Error() + } + return true, "" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -333,7 +412,7 @@ func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, cha Channel: channel, ChatID: chatID, UserMessage: content, - DefaultResponse: defaultResponse, + DefaultResponse: "I've completed processing but have no response to give.", EnableSummary: false, SendResponse: false, NoHistory: true, // Don't load session history for heartbeat @@ -356,6 +435,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "session_key": msg.SessionKey, }) + // Fire message_received hook + if al.hooks != nil { + al.hooks.TriggerMessageReceived(ctx, &hooks.MessageReceivedEvent{ + Channel: msg.Channel, + SenderID: msg.SenderID, + ChatID: msg.ChatID, + Content: msg.Content, + Media: msg.Media, + Metadata: msg.Metadata, + }) + } + // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) @@ -384,13 +475,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) } - // Reset message-tool state for this round so we don't skip publishing due to a previous round. - if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(msg.Channel, msg.ChatID) - } - } - // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) sessionKey := route.SessionKey if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { @@ -409,7 +493,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, - DefaultResponse: defaultResponse, + DefaultResponse: "I've completed processing but have no response to give.", EnableSummary: true, SendResponse: false, }) @@ -490,6 +574,18 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 1. Update tool contexts al.updateToolContexts(agent, opts.Channel, opts.ChatID) + // Fire session hooks + if al.hooks != nil { + sessionEvt := &hooks.SessionEvent{ + AgentID: agent.ID, + SessionKey: opts.SessionKey, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerSessionStart(ctx, sessionEvt) + defer al.hooks.TriggerSessionEnd(ctx, sessionEvt) + } + // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string @@ -529,12 +625,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) + al.maybeSummarize(ctx, agent, opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -576,7 +672,7 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan } // Use a short timeout so the goroutine does not block indefinitely when - // the outbound bus is full. Reasoning output is best-effort; dropping it + // the outbound bus is full. Reasoning output is best-effort; dropping it // is acceptable to avoid goroutine accumulation. pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second) defer pubCancel() @@ -587,7 +683,7 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan Content: reasoningContent, }); err != nil { // Treat context.DeadlineExceeded / context.Canceled as expected - // (bus full under load, or parent canceled). Check the error + // (bus full under load, or parent canceled). Check the error // itself rather than ctx.Err(), because pubCtx may time out // (5 s) while the parent ctx is still active. // Also treat ErrBusClosed as expected — it occurs during normal @@ -684,8 +780,19 @@ func (al *AgentLoop) runLLMIteration( } // Retry loop for context/token errors + llmStart := time.Now() maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { + // Fire llm_input hook (re-fires after compression so hooks see actual messages) + if al.hooks != nil { + al.hooks.TriggerLLMInput(ctx, &hooks.LLMInputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Messages: messages, + Tools: providerToolDefs, + Iteration: iteration, + }) + } response, err = callLLM() if err == nil { break @@ -729,7 +836,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -748,6 +855,8 @@ func (al *AgentLoop) runLLMIteration( break } + llmDuration := time.Since(llmStart) + if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]any{ @@ -760,16 +869,18 @@ func (al *AgentLoop) runLLMIteration( go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) - logger.DebugCF("agent", "LLM response", - map[string]any{ - "agent_id": agent.ID, - "iteration": iteration, - "content_chars": len(response.Content), - "tool_calls": len(response.ToolCalls), - "reasoning": response.Reasoning, - "target_channel": al.targetReasoningChannelID(opts.Channel), - "channel": opts.Channel, + // Fire llm_output hook + if al.hooks != nil { + al.hooks.TriggerLLMOutput(ctx, &hooks.LLMOutputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Content: response.Content, + ToolCalls: response.ToolCalls, + Iteration: iteration, + Duration: llmDuration, }) + } + // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -832,9 +943,14 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) + assistantMsgIndex := len(messages) - 1 + assistantSessionIndex := -1 + if history := agent.Sessions.GetHistory(opts.SessionKey); len(history) > 0 { + assistantSessionIndex = len(history) - 1 + } // Execute tool calls - for _, tc := range normalizedToolCalls { + for tcIdx, tc := range normalizedToolCalls { argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), @@ -860,18 +976,74 @@ func (al *AgentLoop) runLLMIteration( } } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + // Fire before_tool_call hook + var toolResult *tools.ToolResult + toolCanceled := false + if al.hooks != nil { + args := tc.Arguments + if args == nil { + args = make(map[string]any) + } + btcEvent := &hooks.BeforeToolCallEvent{ + ToolName: tc.Name, + Args: args, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerBeforeToolCall(ctx, btcEvent) + if btcEvent.Cancel { + toolCanceled = true + reason := btcEvent.CancelReason + if strings.TrimSpace(reason) == "" { + reason = fmt.Sprintf("tool call %q was canceled by before_tool_call hook", tc.Name) + } + toolResult = tools.ErrorResult(reason) + } + tc.Arguments = btcEvent.Args + if tc.Arguments == nil { + tc.Arguments = make(map[string]any) + } + + // Keep persisted assistant tool-call arguments aligned with rewritten execution args. + updateToolCallArguments(&messages[assistantMsgIndex], tcIdx, tc.Arguments) + if assistantSessionIndex >= 0 { + history := agent.Sessions.GetHistory(opts.SessionKey) + if assistantSessionIndex < len(history) { + updateToolCallArguments(&history[assistantSessionIndex], tcIdx, tc.Arguments) + agent.Sessions.SetHistory(opts.SessionKey, history) + } + } + } + + var toolDuration time.Duration + if !toolCanceled { + toolStart := time.Now() + toolResult = agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + toolDuration = time.Since(toolStart) + } + + // Fire after_tool_call hook (fires for both executed and canceled calls) + if al.hooks != nil { + al.hooks.TriggerAfterToolCall(ctx, &hooks.AfterToolCallEvent{ + ToolName: tc.Name, + Args: tc.Arguments, + Channel: opts.Channel, + ChatID: opts.ChatID, + Duration: toolDuration, + Result: toolResult, + }) + } // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, @@ -898,7 +1070,7 @@ func (al *AgentLoop) runLLMIteration( } parts = append(parts, part) } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ + _ = al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Parts: parts, @@ -947,7 +1119,7 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { +func (al *AgentLoop) maybeSummarize(_ context.Context, agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) threshold := agent.ContextWindow * 75 / 100 @@ -957,6 +1129,13 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { defer al.summarizing.Delete(summarizeKey) + if !constants.IsInternalChannel(channel) { + al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: "Memory threshold reached. Optimizing conversation history...", + }) + } logger.Debug("Memory threshold reached. Optimizing conversation history...") al.summarizeSession(agent, sessionKey) }() @@ -1026,6 +1205,14 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { return info } + pluginNames := make([]string, 0) + if al.pluginManager != nil { + pluginNames = al.pluginManager.Names() + if pluginNames == nil { + pluginNames = make([]string, 0) + } + } + // Tools info toolsList := agent.Tools.List() info["tools"] = map[string]any{ @@ -1033,6 +1220,12 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { "names": toolsList, } + // Plugins info + info["plugins"] = map[string]any{ + "enabled": pluginNames, + "count": len(pluginNames), + } + // Skills info info["skills"] = agent.ContextBuilder.GetSkillsInfo() @@ -1045,6 +1238,19 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { return info } +// updateToolCallArguments patches the serialized arguments for a tool call in-place. +func updateToolCallArguments(msg *providers.Message, toolCallIndex int, args map[string]any) { + if msg == nil || toolCallIndex < 0 || toolCallIndex >= len(msg.ToolCalls) { + return + } + toolCall := &msg.ToolCalls[toolCallIndex] + if toolCall.Function == nil { + return + } + argumentsJSON, _ := json.Marshal(args) + toolCall.Function.Arguments = string(argumentsJSON) +} + // formatMessagesForLog formats messages for logging func formatMessagesForLog(messages []providers.Message) string { if len(messages) == 0 { @@ -1317,20 +1523,27 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } -// extractPeer extracts the routing peer from the inbound message's structured Peer field. +// extractPeer extracts the routing peer from inbound message metadata. func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - if msg.Peer.Kind == "" { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + peerKind = msg.Peer.Kind + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + peerID = msg.Peer.ID + } + if peerKind == "" { return nil } - peerID := msg.Peer.ID if peerID == "" { - if msg.Peer.Kind == "direct" { + if peerKind == "direct" { peerID = msg.SenderID } else { peerID = msg.ChatID } } - return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} + return &routing.RoutePeer{Kind: peerKind, ID: peerID} } // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 801b6a46ed..fc2beb3907 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -321,6 +321,90 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } } +func TestGetStartupInfo_IncludesPluginSummary(t *testing.T) { + newLoop := func(t *testing.T) *AgentLoop { + t.Helper() + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + return NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + } + + t.Run("no plugins enabled", func(t *testing.T) { + al := newLoop(t) + info := al.GetStartupInfo() + + pluginsInfo, ok := info["plugins"].(map[string]any) + if !ok { + t.Fatal("Expected 'plugins' to be a map") + } + + count, ok := pluginsInfo["count"].(int) + if !ok { + t.Fatal("Expected plugin count to be an int") + } + if count != 0 { + t.Fatalf("Expected plugin count 0, got %d", count) + } + + enabled, ok := pluginsInfo["enabled"].([]string) + if !ok { + t.Fatal("Expected plugin enabled list to be []string") + } + if len(enabled) != 0 { + t.Fatalf("Expected no enabled plugins, got %v", enabled) + } + }) + + t.Run("plugins enabled", func(t *testing.T) { + al := newLoop(t) + if err := al.EnablePlugins(blockingPlugin{}); err != nil { + t.Fatalf("EnablePlugins failed: %v", err) + } + + info := al.GetStartupInfo() + pluginsInfo, ok := info["plugins"].(map[string]any) + if !ok { + t.Fatal("Expected 'plugins' to be a map") + } + + count, ok := pluginsInfo["count"].(int) + if !ok { + t.Fatal("Expected plugin count to be an int") + } + if count <= 0 { + t.Fatalf("Expected plugin count > 0, got %d", count) + } + + enabled, ok := pluginsInfo["enabled"].([]string) + if !ok { + t.Fatal("Expected plugin enabled list to be []string") + } + if len(enabled) == 0 { + t.Fatal("Expected at least one enabled plugin") + } + + found := false + for _, name := range enabled { + if name == "block-outbound" { + found = true + break + } + } + if !found { + t.Fatalf("Expected enabled plugin list to include block-outbound, got %v", enabled) + } + }) +} + // TestAgentLoop_Stop verifies Stop() sets running to false func TestAgentLoop_Stop(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/agent/plugin_test.go b/pkg/agent/plugin_test.go new file mode 100644 index 0000000000..c0b6f0625a --- /dev/null +++ b/pkg/agent/plugin_test.go @@ -0,0 +1,416 @@ +package agent + +import ( + "context" + "encoding/json" + "os" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +type blockingPlugin struct{} + +func (p blockingPlugin) Name() string { + return "block-outbound" +} + +func (p blockingPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p blockingPlugin) Register(r *hooks.HookRegistry) error { + r.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked by plugin" + return nil + }) + return nil +} + +type nilArgsProvider struct { + calls int +} + +func (p *nilArgsProvider) Chat( + _ context.Context, + _ []providers.Message, + _ []providers.ToolDefinition, + _ string, + _ map[string]any, +) (*providers.LLMResponse, error) { + if p.calls == 0 { + p.calls++ + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{ + { + ID: "tc-1", + Type: "function", + Name: "nil_args_tool", + Arguments: map[string]any{"seed": "value"}, + }, + }, + }, nil + } + p.calls++ + return &providers.LLMResponse{ + Content: "done", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (p *nilArgsProvider) GetDefaultModel() string { + return "test-model" +} + +type nilArgsCaptureTool struct { + receivedNil bool +} + +func (t *nilArgsCaptureTool) Name() string { + return "nil_args_tool" +} + +func (t *nilArgsCaptureTool) Description() string { + return "captures whether args are nil" +} + +func (t *nilArgsCaptureTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *nilArgsCaptureTool) Execute(_ context.Context, args map[string]any) *tools.ToolResult { + if args == nil { + t.receivedNil = true + } + return tools.SilentResult("ok") +} + +func TestSetPluginManagerInstallsHookRegistry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err != nil { + t.Fatalf("SetPluginManager: %v", err) + } + + if al.pluginManager == nil { + t.Fatal("expected plugin manager to be set") + } + if al.hooks != pm.HookRegistry() { + t.Fatal("expected agent loop hooks to use plugin manager registry") + } + + sent, reason := al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: "cli", + ChatID: "direct", + Content: "hello", + }) + if sent { + t.Fatal("expected outbound message to be blocked by plugin") + } + if reason == "" { + t.Fatal("expected cancel reason to be propagated") + } +} + +func TestSetHooksReturnsErrorWhenRunning(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + if err := al.SetHooks(hooks.NewHookRegistry()); err == nil { + t.Fatal("expected error when calling SetHooks while running") + } +} + +func TestSetPluginManagerDoesNotPartiallyUpdateOnError(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err == nil { + t.Fatal("expected SetPluginManager to fail while running") + } + if al.pluginManager != nil { + t.Fatal("expected plugin manager to remain unchanged on SetPluginManager failure") + } +} + +func TestBeforeToolCallHooksCannotLeaveToolArgsNil(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &nilArgsProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + captureTool := &nilArgsCaptureTool{} + al.RegisterTool(captureTool) + + r := hooks.NewHookRegistry() + r.OnBeforeToolCall("force-nil-args", 0, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + if e.ToolName == "nil_args_tool" { + e.Args = nil + } + return nil + }) + if setErr := al.SetHooks(r); setErr != nil { + t.Fatalf("SetHooks: %v", setErr) + } + + resp, err := al.ProcessDirectWithChannel(context.Background(), "run nil args test", "s1", "cli", "direct") + if err != nil { + t.Fatalf("ProcessDirectWithChannel: %v", err) + } + if resp != "done" { + t.Fatalf("expected final response 'done', got %q", resp) + } + if captureTool.receivedNil { + t.Fatal("expected tool args to be reinitialized to non-nil map") + } +} + +func TestSetHooksNilRestoresDirectMessageCallback(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + tool, ok := agent.Tools.Get("message") + if !ok { + t.Fatal("expected message tool") + } + mt, ok := tool.(*tools.MessageTool) + if !ok { + t.Fatal("expected message tool type") + } + + reg := hooks.NewHookRegistry() + reg.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked-by-hook" + return nil + }) + if err := al.SetHooks(reg); err != nil { + t.Fatalf("SetHooks(reg): %v", err) + } + + blocked := mt.Execute(context.Background(), map[string]any{ + "content": "first", + "channel": "cli", + "chat_id": "direct", + }) + if !blocked.IsError { + t.Fatal("expected message tool call to fail while hooks are active") + } + if blocked.Err == nil || !strings.Contains(blocked.Err.Error(), "blocked-by-hook") { + t.Fatalf("expected hook cancel reason in error, got %#v", blocked.Err) + } + + ctxNoMsg, cancelNoMsg := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancelNoMsg() + if _, got := msgBus.SubscribeOutbound(ctxNoMsg); got { + t.Fatal("did not expect outbound message while hook cancellation is active") + } + + if err := al.SetHooks(nil); err != nil { + t.Fatalf("SetHooks(nil): %v", err) + } + + delivered := mt.Execute(context.Background(), map[string]any{ + "content": "second", + "channel": "cli", + "chat_id": "direct", + }) + if delivered.IsError { + t.Fatalf("expected message tool to succeed after SetHooks(nil), got %#v", delivered) + } + + ctxMsg, cancelMsg := context.WithTimeout(context.Background(), time.Second) + defer cancelMsg() + msg, got := msgBus.SubscribeOutbound(ctxMsg) + if !got { + t.Fatal("expected outbound message after SetHooks(nil)") + } + if msg.Content != "second" || msg.Channel != "cli" || msg.ChatID != "direct" { + t.Fatalf("unexpected outbound message: %#v", msg) + } +} + +func TestBeforeToolCallArgRewriteUpdatesAssistantTranscript(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &nilArgsProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(&nilArgsCaptureTool{}) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + sessionKey := "agent:" + defaultAgent.ID + ":s2" + + reg := hooks.NewHookRegistry() + reg.OnBeforeToolCall("rewrite-args", 0, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + e.Args["rewritten"] = "yes" + return nil + }) + if err := al.SetHooks(reg); err != nil { + t.Fatalf("SetHooks: %v", err) + } + + if _, err := al.ProcessDirectWithChannel( + context.Background(), + "run rewrite test", + sessionKey, + "cli", + "direct", + ); err != nil { + t.Fatalf("ProcessDirectWithChannel: %v", err) + } + + history := defaultAgent.Sessions.GetHistory(sessionKey) + + foundToolCall := false + for _, msg := range history { + if msg.Role != "assistant" || len(msg.ToolCalls) == 0 { + continue + } + if msg.ToolCalls[0].Function == nil { + t.Fatal("expected tool call function payload") + } + var args map[string]any + if err := json.Unmarshal([]byte(msg.ToolCalls[0].Function.Arguments), &args); err != nil { + t.Fatalf("failed to decode persisted tool call args: %v", err) + } + if got := args["rewritten"]; got != "yes" { + t.Fatalf("expected rewritten arg to be persisted, got %#v", got) + } + foundToolCall = true + break + } + + if !foundToolCall { + t.Fatal("expected assistant tool call message in session history") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index d84772d2b0..92549b7778 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -53,6 +53,7 @@ type Config struct { Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig `json:"providers,omitempty"` + Plugins PluginsConfig `json:"plugins,omitempty"` ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` @@ -167,6 +168,12 @@ type SessionConfig struct { IdentityLinks map[string][]string `json:"identity_links,omitempty"` } +type PluginsConfig struct { + DefaultEnabled bool `json:"default_enabled"` + Enabled []string `json:"enabled,omitempty"` + Disabled []string `json:"disabled,omitempty"` +} + type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 12fd10b50b..7e25b38293 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -243,6 +243,51 @@ func TestDefaultConfig_Temperature(t *testing.T) { } } +func TestDefaultConfig_PluginsDefaults(t *testing.T) { + cfg := DefaultConfig() + + if !cfg.Plugins.DefaultEnabled { + t.Error("Plugins.DefaultEnabled should be true by default") + } + if cfg.Plugins.Enabled == nil { + t.Error("Plugins.Enabled should be initialized to an empty slice") + } + if len(cfg.Plugins.Enabled) != 0 { + t.Errorf("Plugins.Enabled len = %d, want 0", len(cfg.Plugins.Enabled)) + } + if cfg.Plugins.Disabled == nil { + t.Error("Plugins.Disabled should be initialized to an empty slice") + } + if len(cfg.Plugins.Disabled) != 0 { + t.Errorf("Plugins.Disabled len = %d, want 0", len(cfg.Plugins.Disabled)) + } +} + +func TestConfig_PluginsJSONUnmarshal(t *testing.T) { + jsonData := `{ + "plugins": { + "default_enabled": false, + "enabled": ["plugin-a", "plugin-b"], + "disabled": ["plugin-c"] + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if cfg.Plugins.DefaultEnabled { + t.Error("Plugins.DefaultEnabled = true, want false") + } + if len(cfg.Plugins.Enabled) != 2 || cfg.Plugins.Enabled[0] != "plugin-a" || cfg.Plugins.Enabled[1] != "plugin-b" { + t.Errorf("Plugins.Enabled = %v, want [plugin-a plugin-b]", cfg.Plugins.Enabled) + } + if len(cfg.Plugins.Disabled) != 1 || cfg.Plugins.Disabled[0] != "plugin-c" { + t.Errorf("Plugins.Disabled = %v, want [plugin-c]", cfg.Plugins.Disabled) + } +} + // TestDefaultConfig_Gateway verifies gateway defaults func TestDefaultConfig_Gateway(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index ebb924859f..61f61ab34b 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -134,6 +134,11 @@ func DefaultConfig() *Config { Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, }, + Plugins: PluginsConfig{ + DefaultEnabled: true, + Enabled: []string{}, + Disabled: []string{}, + }, ModelList: []ModelConfig{ // ============================================ // Add your API key to the model you want to use diff --git a/pkg/hooks/hooks.go b/pkg/hooks/hooks.go new file mode 100644 index 0000000000..9865c86fe4 --- /dev/null +++ b/pkg/hooks/hooks.go @@ -0,0 +1,499 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const voidHookWaitBudget = 50 * time.Millisecond + +// HookHandler is the callback signature for all hooks. +type HookHandler[T any] func(ctx context.Context, event *T) error + +// HookRegistration tracks a handler with its priority and name. +type HookRegistration[T any] struct { + Handler HookHandler[T] + Priority int // Lower = runs first + Name string +} + +// HookRegistry manages all lifecycle hooks. +type HookRegistry struct { + messageReceived []HookRegistration[MessageReceivedEvent] + messageSending []HookRegistration[MessageSendingEvent] + beforeToolCall []HookRegistration[BeforeToolCallEvent] + afterToolCall []HookRegistration[AfterToolCallEvent] + llmInput []HookRegistration[LLMInputEvent] + llmOutput []HookRegistration[LLMOutputEvent] + sessionStart []HookRegistration[SessionEvent] + sessionEnd []HookRegistration[SessionEvent] + mu sync.RWMutex +} + +// NewHookRegistry creates an empty hook registry. +func NewHookRegistry() *HookRegistry { + return &HookRegistry{} +} + +// insertSorted inserts a registration into a new slice sorted by priority. +// Always allocates a new backing array so concurrent readers of the old slice are safe. +func insertSorted[T any](slice []HookRegistration[T], reg HookRegistration[T]) []HookRegistration[T] { + i := 0 + for i < len(slice) && slice[i].Priority <= reg.Priority { + i++ + } + result := make([]HookRegistration[T], len(slice)+1) + copy(result, slice[:i]) + result[i] = reg + copy(result[i+1:], slice[i:]) + return result +} + +// Registration methods + +func (r *HookRegistry) OnMessageReceived(name string, priority int, handler HookHandler[MessageReceivedEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageReceived = insertSorted(r.messageReceived, HookRegistration[MessageReceivedEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnMessageSending(name string, priority int, handler HookHandler[MessageSendingEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageSending = insertSorted(r.messageSending, HookRegistration[MessageSendingEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnBeforeToolCall(name string, priority int, handler HookHandler[BeforeToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.beforeToolCall = insertSorted(r.beforeToolCall, HookRegistration[BeforeToolCallEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnAfterToolCall(name string, priority int, handler HookHandler[AfterToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.afterToolCall = insertSorted(r.afterToolCall, HookRegistration[AfterToolCallEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMInput(name string, priority int, handler HookHandler[LLMInputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmInput = insertSorted(r.llmInput, HookRegistration[LLMInputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMOutput(name string, priority int, handler HookHandler[LLMOutputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmOutput = insertSorted(r.llmOutput, HookRegistration[LLMOutputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionStart(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionStart = insertSorted(r.sessionStart, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionEnd(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionEnd = insertSorted(r.sessionEnd, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +// Trigger methods — void hooks + +func cloneMapStringString(src map[string]string) map[string]string { + if src == nil { + return nil + } + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = cloneAny(v) + } + return dst +} + +func cloneAny(v any) any { + if v == nil { + return nil + } + cloned := cloneReflectValue(reflect.ValueOf(v)) + if !cloned.IsValid() { + return nil + } + return cloned.Interface() +} + +func cloneReflectValue(v reflect.Value) reflect.Value { + if !v.IsValid() { + return v + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type().Elem()) + out.Elem().Set(cloneReflectValue(v.Elem())) + return out + case reflect.Interface: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type()).Elem() + out.Set(cloneReflectValue(v.Elem())) + return out + case reflect.Map: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeMapWithSize(v.Type(), v.Len()) + iter := v.MapRange() + for iter.Next() { + out.SetMapIndex(iter.Key(), cloneReflectValue(iter.Value())) + } + return out + case reflect.Slice: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeSlice(v.Type(), v.Len(), v.Len()) + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + case reflect.Array: + out := reflect.New(v.Type()).Elem() + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + case reflect.Struct: + out := reflect.New(v.Type()).Elem() + for i := range v.NumField() { + field := out.Field(i) + if !field.CanSet() { + // Preserve original value for structs with non-settable fields. + return v + } + field.Set(cloneReflectValue(v.Field(i))) + } + return out + default: + return v + } +} + +func cloneToolCall(tc providers.ToolCall) providers.ToolCall { + out := tc + out.Arguments = cloneMapStringAny(tc.Arguments) + if tc.Function != nil { + f := *tc.Function + out.Function = &f + } + if tc.ExtraContent != nil { + ec := *tc.ExtraContent + if tc.ExtraContent.Google != nil { + g := *tc.ExtraContent.Google + ec.Google = &g + } + out.ExtraContent = &ec + } + return out +} + +func cloneMessage(msg providers.Message) providers.Message { + out := msg + if msg.ToolCalls != nil { + out.ToolCalls = make([]providers.ToolCall, len(msg.ToolCalls)) + for i := range msg.ToolCalls { + out.ToolCalls[i] = cloneToolCall(msg.ToolCalls[i]) + } + } + if msg.SystemParts != nil { + out.SystemParts = make([]providers.ContentBlock, len(msg.SystemParts)) + for i := range msg.SystemParts { + part := msg.SystemParts[i] + if part.CacheControl != nil { + cc := *part.CacheControl + part.CacheControl = &cc + } + out.SystemParts[i] = part + } + } + return out +} + +func cloneToolDefinition(td providers.ToolDefinition) providers.ToolDefinition { + out := td + out.Function = td.Function + out.Function.Parameters = cloneMapStringAny(td.Function.Parameters) + return out +} + +func cloneVoidEvent[T any](event *T) *T { + if event == nil { + return nil + } + + switch e := any(event).(type) { + case *MessageReceivedEvent: + c := *e + if e.Media != nil { + c.Media = append([]string(nil), e.Media...) + } + c.Metadata = cloneMapStringString(e.Metadata) + return any(&c).(*T) + case *AfterToolCallEvent: + c := *e + c.Args = cloneMapStringAny(e.Args) + if e.Result != nil { + r := *e.Result + c.Result = &r + } + return any(&c).(*T) + case *LLMInputEvent: + c := *e + if e.Messages != nil { + c.Messages = make([]providers.Message, len(e.Messages)) + for i := range e.Messages { + c.Messages[i] = cloneMessage(e.Messages[i]) + } + } + if e.Tools != nil { + c.Tools = make([]providers.ToolDefinition, len(e.Tools)) + for i := range e.Tools { + c.Tools[i] = cloneToolDefinition(e.Tools[i]) + } + } + return any(&c).(*T) + case *LLMOutputEvent: + c := *e + if e.ToolCalls != nil { + c.ToolCalls = make([]providers.ToolCall, len(e.ToolCalls)) + for i := range e.ToolCalls { + c.ToolCalls[i] = cloneToolCall(e.ToolCalls[i]) + } + } + return any(&c).(*T) + case *SessionEvent: + c := *e + return any(&c).(*T) + default: + c := *event + return &c + } +} + +// triggerVoid runs all handlers concurrently. +// It waits up to a small budget to collect immediate completions, then +// continues fail-open to avoid blocking the core agent pipeline. +// Each handler receives a cloned event to avoid shared-state mutation races. +// Errors are logged but do not propagate to the caller. +func triggerVoid[T any](ctx context.Context, hooks []HookRegistration[T], event *T, hookName string) { + if len(hooks) == 0 { + return + } + var wg sync.WaitGroup + for _, h := range hooks { + wg.Add(1) + go func(reg HookRegistration[T]) { + defer wg.Done() + eventCopy := cloneVoidEvent(event) + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := reg.Handler(ctx, eventCopy); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "error": err.Error(), + }) + } + }(h) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + logger.WarnCF("hooks", "Void hook dispatch interrupted by context", + map[string]any{ + "hook": hookName, + }) + case <-time.After(voidHookWaitBudget): + logger.WarnCF("hooks", "Void hook dispatch exceeded wait budget; continuing", + map[string]any{ + "hook": hookName, + "wait_budget_ms": voidHookWaitBudget.Milliseconds(), + }) + } +} + +// triggerModifying runs handlers sequentially by priority, stopping if Cancel is set. +// The cancelCheck function inspects the event to determine if Cancel was set. +func triggerModifying[T any]( + ctx context.Context, + hooks []HookRegistration[T], + event *T, + hookName string, + cancelCheck func(*T) bool, +) { + if len(hooks) == 0 { + return + } + for _, h := range hooks { + func() { + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := h.Handler(ctx, event); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "error": err.Error(), + }) + } + }() + if cancelCheck(event) { + logger.InfoCF("hooks", "Hook canceled operation", + map[string]any{ + "hook": hookName, + "handler": h.Name, + }) + return + } + } +} + +// TriggerMessageReceived fires all message_received handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerMessageReceived(ctx context.Context, event *MessageReceivedEvent) { + r.mu.RLock() + hooks := r.messageReceived + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "message_received") +} + +func (r *HookRegistry) TriggerMessageSending(ctx context.Context, event *MessageSendingEvent) { + r.mu.RLock() + hooks := r.messageSending + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "message_sending", func(e *MessageSendingEvent) bool { + return e.Cancel + }) +} + +func (r *HookRegistry) TriggerBeforeToolCall(ctx context.Context, event *BeforeToolCallEvent) { + r.mu.RLock() + hooks := r.beforeToolCall + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "before_tool_call", func(e *BeforeToolCallEvent) bool { + return e.Cancel + }) +} + +// TriggerAfterToolCall fires all after_tool_call handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerAfterToolCall(ctx context.Context, event *AfterToolCallEvent) { + r.mu.RLock() + hooks := r.afterToolCall + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "after_tool_call") +} + +// TriggerLLMInput fires all llm_input handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMInput(ctx context.Context, event *LLMInputEvent) { + r.mu.RLock() + hooks := r.llmInput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_input") +} + +// TriggerLLMOutput fires all llm_output handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMOutput(ctx context.Context, event *LLMOutputEvent) { + r.mu.RLock() + hooks := r.llmOutput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_output") +} + +// TriggerSessionStart fires all session_start handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerSessionStart(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionStart + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_start") +} + +// TriggerSessionEnd fires all session_end handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerSessionEnd(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionEnd + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_end") +} diff --git a/pkg/hooks/hooks_test.go b/pkg/hooks/hooks_test.go new file mode 100644 index 0000000000..d21467a550 --- /dev/null +++ b/pkg/hooks/hooks_test.go @@ -0,0 +1,657 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestNewHookRegistry(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Triggering all hooks on an empty registry should not panic. + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "hello"}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hello"}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "t"}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ToolName: "t"}) + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a"}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a"}) + r.TriggerSessionStart(ctx, &SessionEvent{AgentID: "a"}) + r.TriggerSessionEnd(ctx, &SessionEvent{AgentID: "a"}) +} + +func TestVoidHookExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called atomic.Bool + r.OnMessageReceived("test", 0, func(_ context.Context, e *MessageReceivedEvent) error { + called.Store(true) + if e.Content != "ping" { + t.Errorf("Expected content 'ping', got '%s'", e.Content) + } + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "ping"}) + + if !called.Load() { + t.Error("Expected handler to be called") + } +} + +func TestVoidHooksConcurrent(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var count atomic.Int32 + started := make(chan struct{}, 5) + release := make(chan struct{}) + done := make(chan struct{}) + + for i := range 5 { + r.OnMessageReceived("hook-"+string(rune('A'+i)), i, func(_ context.Context, _ *MessageReceivedEvent) error { + started <- struct{}{} + <-release + count.Add(1) + return nil + }) + } + + go func() { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + close(done) + }() + + // All 5 handlers must reach the barrier concurrently. + for i := range 5 { + select { + case <-started: + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for handler %d to start", i+1) + } + } + + // Release all handlers. + close(release) + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for handlers to complete") + } + + if count.Load() != 5 { + t.Errorf("Expected 5 handlers called, got %d", count.Load()) + } +} + +func TestVoidHooksReceiveIsolatedMessageReceivedEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageReceived("mutator-a", 0, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-a" + e.Media[0] = "changed-media-a" + e.Metadata["k"] = "changed-a" + e.Metadata["new-a"] = "x" + return nil + }) + r.OnMessageReceived("mutator-b", 1, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-b" + e.Media = append(e.Media, "extra") + e.Metadata["k"] = "changed-b" + e.Metadata["new-b"] = "y" + return nil + }) + + event := &MessageReceivedEvent{ + Content: "original", + Media: []string{"m1"}, + Metadata: map[string]string{"k": "v"}, + } + r.TriggerMessageReceived(ctx, event) + + if event.Content != "original" { + t.Fatalf("expected original content to remain unchanged, got %q", event.Content) + } + if len(event.Media) != 1 || event.Media[0] != "m1" { + t.Fatalf("expected original media to remain unchanged, got %#v", event.Media) + } + if got := event.Metadata["k"]; got != "v" { + t.Fatalf("expected metadata[k] to remain v, got %q", got) + } + if _, ok := event.Metadata["new-a"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-a") + } + if _, ok := event.Metadata["new-b"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-b") + } +} + +func TestVoidHooksReceiveIsolatedAfterToolCallEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnAfterToolCall("mutator-a", 0, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-a" + e.Result.ForLLM = "mutated-a" + return nil + }) + r.OnAfterToolCall("mutator-b", 1, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-b" + e.Args["new"] = "v" + e.Result.ForUser = "mutated-b" + return nil + }) + + event := &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"k": "original"}, + Result: &tools.ToolResult{ + ForLLM: "for-llm", + ForUser: "for-user", + }, + } + + // Use a local copy so we can compare immutable expectations. + r.TriggerAfterToolCall(ctx, event) + + if got := event.Args["k"]; got != "original" { + t.Fatalf("expected args[k] to remain original, got %#v", got) + } + if _, ok := event.Args["new"]; ok { + t.Fatal("unexpected args mutation leaked from hook") + } + if event.Result.ForLLM != "for-llm" { + t.Fatalf("expected original result.ForLLM to remain unchanged, got %q", event.Result.ForLLM) + } + if event.Result.ForUser != "for-user" { + t.Fatalf("expected original result.ForUser to remain unchanged, got %q", event.Result.ForUser) + } +} + +func TestVoidHooksReceiveIsolatedLLMInputToolSchema(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnLLMInput("mutator", 0, func(_ context.Context, e *LLMInputEvent) error { + required, ok := e.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should be []string") + } + required[0] = "mutated" + e.Tools[0].Function.Parameters["required"] = append(required, "extra") + return nil + }) + + event := &LLMInputEvent{ + AgentID: "a1", + Model: "m1", + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "message", + Parameters: map[string]any{ + "type": "object", + "required": []string{"content"}, + }, + }, + }, + }, + } + + r.TriggerLLMInput(ctx, event) + + required, ok := event.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should remain []string") + } + if len(required) != 1 || required[0] != "content" { + t.Fatalf("expected required to remain unchanged, got %#v", required) + } +} + +func TestVoidHooksReceiveIsolatedStructValuesInMap(t *testing.T) { + type schemaSpec struct { + Required []string + Meta map[string]string + } + + r := NewHookRegistry() + ctx := context.Background() + + r.OnLLMInput("struct-mutator", 0, func(_ context.Context, e *LLMInputEvent) error { + spec, ok := e.Tools[0].Function.Parameters["schema"].(schemaSpec) + if !ok { + t.Fatal("schema should be schemaSpec") + } + spec.Required[0] = "mutated" + spec.Meta["k"] = "changed" + e.Tools[0].Function.Parameters["schema"] = spec + return nil + }) + + event := &LLMInputEvent{ + AgentID: "a1", + Model: "m1", + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "message", + Parameters: map[string]any{ + "schema": schemaSpec{ + Required: []string{"content"}, + Meta: map[string]string{"k": "v"}, + }, + }, + }, + }, + }, + } + + r.TriggerLLMInput(ctx, event) + + spec, ok := event.Tools[0].Function.Parameters["schema"].(schemaSpec) + if !ok { + t.Fatal("schema should remain schemaSpec") + } + if len(spec.Required) != 1 || spec.Required[0] != "content" { + t.Fatalf("expected required to remain unchanged, got %#v", spec.Required) + } + if got := spec.Meta["k"]; got != "v" { + t.Fatalf("expected meta[k] to remain v, got %q", got) + } +} + +func TestVoidHooksFailOpenOnSlowHandler(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + started := make(chan struct{}) + release := make(chan struct{}) + done := make(chan struct{}) + + r.OnLLMInput("slow", 0, func(_ context.Context, _ *LLMInputEvent) error { + close(started) + <-release + close(done) + return nil + }) + + begin := time.Now() + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a1"}) + elapsed := time.Since(begin) + + if elapsed > voidHookWaitBudget*3 { + t.Fatalf("expected fail-open dispatch within budget, got %s", elapsed) + } + + select { + case <-started: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for slow handler to start") + } + + close(release) + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for slow handler to finish after release") + } +} + +func TestModifyingHookPriority(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var mu sync.Mutex + var order []string + + // Register in reverse priority order to verify sorting. + r.OnMessageSending("third", 30, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "third") + mu.Unlock() + return nil + }) + r.OnMessageSending("first", 10, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "first") + mu.Unlock() + return nil + }) + r.OnMessageSending("second", 20, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "second") + mu.Unlock() + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hi"}) + + if len(order) != 3 { + t.Fatalf("Expected 3 handlers, got %d", len(order)) + } + if order[0] != "first" || order[1] != "second" || order[2] != "third" { + t.Errorf("Expected [first second third], got %v", order) + } +} + +func TestModifyingHookCancel(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var secondCalled bool + + r.OnMessageSending("canceler", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked" + return nil + }) + r.OnMessageSending("after-cancel", 20, func(_ context.Context, _ *MessageSendingEvent) error { + secondCalled = true + return nil + }) + + event := &MessageSendingEvent{Content: "hi"} + r.TriggerMessageSending(ctx, event) + + if !event.Cancel { + t.Error("Expected Cancel to be true") + } + if secondCalled { + t.Error("Expected second handler NOT to be called after cancel") + } +} + +func TestBeforeToolCallModification(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnBeforeToolCall("modifier", 10, func(_ context.Context, e *BeforeToolCallEvent) error { + e.Args["injected"] = "value" + return nil + }) + + event := &BeforeToolCallEvent{ + ToolName: "search", + Args: map[string]any{"query": "test"}, + } + r.TriggerBeforeToolCall(ctx, event) + + if event.Args["injected"] != "value" { + t.Error("Expected injected arg to persist") + } + if event.Args["query"] != "test" { + t.Error("Expected original arg to remain") + } +} + +func TestMessageSendingFilter(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageSending("rewriter", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Content = "[filtered] " + e.Content + return nil + }) + + event := &MessageSendingEvent{Content: "hello world"} + r.TriggerMessageSending(ctx, event) + + if event.Content != "[filtered] hello world" { + t.Errorf("Expected '[filtered] hello world', got '%s'", event.Content) + } +} + +func TestZeroCostWhenEmpty(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // This is primarily a safety/smoke test — no panics, no allocations of note. + for range 100 { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{}) + r.TriggerLLMInput(ctx, &LLMInputEvent{}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{}) + r.TriggerSessionStart(ctx, &SessionEvent{}) + r.TriggerSessionEnd(ctx, &SessionEvent{}) + } +} + +func TestLLMInputOutput(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var inputCalled, outputCalled atomic.Bool + + r.OnLLMInput("input-hook", 0, func(_ context.Context, e *LLMInputEvent) error { + if e.Model != "gpt-4" { + t.Errorf("Expected model 'gpt-4', got '%s'", e.Model) + } + inputCalled.Store(true) + return nil + }) + + r.OnLLMOutput("output-hook", 0, func(_ context.Context, e *LLMOutputEvent) error { + if e.Content != "response" { + t.Errorf("Expected content 'response', got '%s'", e.Content) + } + outputCalled.Store(true) + return nil + }) + + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a1", Model: "gpt-4", Iteration: 1}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a1", Model: "gpt-4", Content: "response", Iteration: 1}) + + if !inputCalled.Load() { + t.Error("Expected LLM input hook to be called") + } + if !outputCalled.Load() { + t.Error("Expected LLM output hook to be called") + } +} + +func TestSessionStartEnd(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var startCalled, endCalled atomic.Bool + + r.OnSessionStart("start-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + startCalled.Store(true) + return nil + }) + + r.OnSessionEnd("end-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + endCalled.Store(true) + return nil + }) + + event := &SessionEvent{AgentID: "a1", SessionKey: "sess-1", Channel: "test", ChatID: "c1"} + r.TriggerSessionStart(ctx, event) + r.TriggerSessionEnd(ctx, event) + + if !startCalled.Load() { + t.Error("Expected session start hook to be called") + } + if !endCalled.Load() { + t.Error("Expected session end hook to be called") + } +} + +func TestConcurrentRegistrationAndTrigger(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var wg sync.WaitGroup + + // Goroutines registering hooks. + for i := range 10 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r.OnMessageReceived( + fmt.Sprintf("reg-hook-%d", idx), + idx, + func(_ context.Context, _ *MessageReceivedEvent) error { + return nil + }, + ) + }(i) + } + + // Goroutines triggering hooks concurrently. + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "race"}) + }() + } + + wg.Wait() +} + +func TestInsertSorted(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var order []int + + // Register with priorities: 50, 10, 30, 20, 40 + priorities := []int{50, 10, 30, 20, 40} + for _, p := range priorities { + r.OnBeforeToolCall(fmt.Sprintf("p-%d", p), p, func(_ context.Context, _ *BeforeToolCallEvent) error { + order = append(order, p) + return nil + }) + } + + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test", Args: map[string]any{}}) + + expected := []int{10, 20, 30, 40, 50} + if len(order) != len(expected) { + t.Fatalf("Expected %d handlers, got %d", len(expected), len(order)) + } + for i, v := range expected { + if order[i] != v { + t.Errorf("Position %d: expected priority %d, got %d", i, v, order[i]) + } + } +} + +func TestAfterToolCallExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called bool + var capturedName string + r.OnAfterToolCall("logger", 0, func(_ context.Context, event *AfterToolCallEvent) error { + called = true + capturedName = event.ToolName + return nil + }) + + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"cmd": "ls"}, + Channel: "telegram", + ChatID: "123", + }) + + if !called { + t.Error("Expected after_tool_call handler to be called") + } + if capturedName != "shell" { + t.Errorf("Expected ToolName 'shell', got '%s'", capturedName) + } +} + +func TestHandlerErrorsSwallowed(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Test void hooks: error in one handler doesn't prevent others from running + var secondCalled bool + r.OnMessageReceived("erroring", 10, func(_ context.Context, _ *MessageReceivedEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageReceived("observer", 20, func(_ context.Context, _ *MessageReceivedEvent) error { + secondCalled = true + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + if !secondCalled { + t.Error("Expected second void handler to run despite first handler's error") + } + + // Test modifying hooks: error doesn't stop chain (only Cancel does) + var modifySecondCalled bool + r.OnMessageSending("erroring", 10, func(_ context.Context, _ *MessageSendingEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageSending("modifier", 20, func(_ context.Context, _ *MessageSendingEvent) error { + modifySecondCalled = true + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "test"}) + if !modifySecondCalled { + t.Error("Expected second modifying handler to run despite first handler's error") + } +} + +func TestPanicRecovery(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Void hook: panic in one handler shouldn't crash, other handlers should still run + var safeHandlerCalled bool + r.OnLLMInput("panicker", 10, func(_ context.Context, _ *LLMInputEvent) error { + panic("boom") + }) + r.OnLLMInput("safe", 10, func(_ context.Context, _ *LLMInputEvent) error { + safeHandlerCalled = true + return nil + }) + + // Should not panic + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "test"}) + if !safeHandlerCalled { + t.Error("Expected safe handler to run despite panicking sibling") + } + + // Modifying hook: panic in handler shouldn't crash + r.OnBeforeToolCall("panicker", 10, func(_ context.Context, _ *BeforeToolCallEvent) error { + panic("boom") + }) + + // Should not panic + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test"}) +} diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go new file mode 100644 index 0000000000..4a0f6697d9 --- /dev/null +++ b/pkg/hooks/types.go @@ -0,0 +1,82 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// MessageReceivedEvent is fired when an inbound message is consumed from the bus. +type MessageReceivedEvent struct { + Channel string + SenderID string + ChatID string + Content string + Media []string + Metadata map[string]string +} + +// MessageSendingEvent is fired before an outbound message is published. +// Handlers can modify Content or set Cancel to block delivery. +type MessageSendingEvent struct { + Channel string + ChatID string + Content string // Modifiable + Cancel bool + CancelReason string +} + +// BeforeToolCallEvent is fired before a tool is executed. +// Handlers can modify Args, or set Cancel to block execution. +type BeforeToolCallEvent struct { + ToolName string + Args map[string]any // Modifiable; guaranteed non-nil when triggered via AgentLoop. + Channel string + ChatID string + Cancel bool + CancelReason string // Message returned to LLM when canceled +} + +// AfterToolCallEvent is fired after a tool completes execution. +type AfterToolCallEvent struct { + ToolName string + Args map[string]any + Channel string + ChatID string + Duration time.Duration + Result *tools.ToolResult +} + +// LLMInputEvent is fired before the LLM provider is called. +type LLMInputEvent struct { + AgentID string + Model string + Messages []providers.Message + Tools []providers.ToolDefinition + Iteration int +} + +// LLMOutputEvent is fired after the LLM provider responds. +type LLMOutputEvent struct { + AgentID string + Model string + Content string + ToolCalls []providers.ToolCall + Iteration int + Duration time.Duration +} + +// SessionEvent is fired at session start and end. +type SessionEvent struct { + AgentID string + SessionKey string + Channel string + ChatID string +} diff --git a/pkg/plugin/builtin/catalog.go b/pkg/plugin/builtin/catalog.go new file mode 100644 index 0000000000..807a4d0ecf --- /dev/null +++ b/pkg/plugin/builtin/catalog.go @@ -0,0 +1,31 @@ +package builtin + +import ( + "sort" + + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/demoplugin" +) + +// Factory creates one builtin plugin instance. +type Factory func() plugin.Plugin + +// Catalog returns compile-time builtin plugin factories by name. +func Catalog() map[string]Factory { + return map[string]Factory{ + "policy-demo": func() plugin.Plugin { + return demoplugin.NewPolicyDemoPlugin(demoplugin.PolicyDemoConfig{}) + }, + } +} + +// Names returns sorted builtin plugin names. +func Names() []string { + catalog := Catalog() + names := make([]string, 0, len(catalog)) + for name := range catalog { + names = append(names, name) + } + sort.Strings(names) + return names +} diff --git a/pkg/plugin/builtin/catalog_test.go b/pkg/plugin/builtin/catalog_test.go new file mode 100644 index 0000000000..8a48a3c997 --- /dev/null +++ b/pkg/plugin/builtin/catalog_test.go @@ -0,0 +1,32 @@ +package builtin + +import ( + "slices" + "testing" +) + +func TestCatalogContainsPolicyDemo(t *testing.T) { + catalog := Catalog() + factory, ok := catalog["policy-demo"] + if !ok { + t.Fatalf("Catalog() missing %q plugin", "policy-demo") + } + if factory == nil { + t.Fatalf("Catalog()[%q] factory is nil", "policy-demo") + } + if got := factory(); got == nil { + t.Fatalf("Catalog()[%q]() returned nil plugin", "policy-demo") + } +} + +func TestNamesSorted(t *testing.T) { + first := Names() + second := Names() + + if !slices.IsSorted(first) { + t.Fatalf("Names() is not sorted: %v", first) + } + if !slices.Equal(first, second) { + t.Fatalf("Names() is not deterministic across calls: %v vs %v", first, second) + } +} diff --git a/pkg/plugin/demoplugin/policy_demo.go b/pkg/plugin/demoplugin/policy_demo.go new file mode 100644 index 0000000000..6e89b1dae8 --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo.go @@ -0,0 +1,315 @@ +package demoplugin + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +// PolicyDemoConfig controls the demo plugin behavior. +type PolicyDemoConfig struct { + BlockedTools []string + RedactPrefixes []string + ChannelToolAllowlist map[string][]string + DenyOutboundPatterns []string + MaxToolTimeoutSecond int +} + +// PolicyDemoStats provides basic evidence that hook paths were executed. +type PolicyDemoStats struct { + BeforeToolCalls int + BlockedToolCalls int + MessageSends int + RedactedMessages int + BlockedMessages int + SessionStarts int + SessionEnds int + AfterToolCalls int + TotalToolDuration time.Duration +} + +// PolicyDemoPlugin demonstrates why plugins are needed: it enforces runtime policy +// at tool-call and outbound-message lifecycle points and collects audit metrics. +type PolicyDemoPlugin struct { + blockedTools map[string]struct{} + prefixes []string + channelAllowlist map[string]map[string]struct{} + denyPatterns []string + maxTimeout int + + mu sync.Mutex + stats PolicyDemoStats +} + +func NewPolicyDemoPlugin(cfg PolicyDemoConfig) *PolicyDemoPlugin { + blocked := make(map[string]struct{}, len(cfg.BlockedTools)) + for _, t := range cfg.BlockedTools { + t = normalizeLower(t) + if t == "" { + continue + } + blocked[t] = struct{}{} + } + + prefixes := make([]string, 0, len(cfg.RedactPrefixes)) + for _, p := range cfg.RedactPrefixes { + p = strings.TrimSpace(p) + if p == "" { + continue + } + prefixes = append(prefixes, p) + } + + allowlist := make(map[string]map[string]struct{}, len(cfg.ChannelToolAllowlist)) + for channel, tools := range cfg.ChannelToolAllowlist { + channel = normalizeLower(channel) + if channel == "" { + continue + } + toolSet := make(map[string]struct{}, len(tools)) + for _, t := range tools { + t = normalizeLower(t) + if t == "" { + continue + } + toolSet[t] = struct{}{} + } + allowlist[channel] = toolSet + } + + patterns := make([]string, 0, len(cfg.DenyOutboundPatterns)) + for _, p := range cfg.DenyOutboundPatterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + patterns = append(patterns, p) + } + + maxTimeout := cfg.MaxToolTimeoutSecond + if maxTimeout < 0 { + maxTimeout = 0 + } + + return &PolicyDemoPlugin{ + blockedTools: blocked, + prefixes: prefixes, + channelAllowlist: allowlist, + denyPatterns: patterns, + maxTimeout: maxTimeout, + } +} + +func (p *PolicyDemoPlugin) Name() string { + return "policy-demo" +} + +func (p *PolicyDemoPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p *PolicyDemoPlugin) Snapshot() PolicyDemoStats { + p.mu.Lock() + defer p.mu.Unlock() + return p.stats +} + +func (p *PolicyDemoPlugin) Register(r *hooks.HookRegistry) error { + r.OnBeforeToolCall("policy-demo-tool-policy", 100, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + tool := normalizeLower(e.ToolName) + p.incBeforeToolCalls() + + if _, blocked := p.blockedTools[tool]; blocked { + e.Cancel = true + e.CancelReason = "blocked by policy-demo plugin" + p.incBlockedToolCalls() + return nil + } + + channel := normalizeLower(e.Channel) + if allow, ok := p.channelAllowlist[channel]; ok { + if _, allowed := allow[tool]; !allowed { + e.Cancel = true + e.CancelReason = fmt.Sprintf("tool %q is not allowed on channel %q", e.ToolName, e.Channel) + p.incBlockedToolCalls() + return nil + } + } + + if p.maxTimeout > 0 { + clampArgNumber(e.Args, "timeout", p.maxTimeout) + clampArgNumber(e.Args, "timeout_seconds", p.maxTimeout) + } + return nil + }) + + r.OnMessageSending("policy-demo-redact-and-guard", 50, func(_ context.Context, e *hooks.MessageSendingEvent) error { + p.incMessageSends() + + for _, pattern := range p.denyPatterns { + if strings.Contains(e.Content, pattern) { + e.Cancel = true + e.CancelReason = "blocked by policy-demo outbound guard" + p.incBlockedMessages() + return nil + } + } + + content := e.Content + redacted := false + for _, prefix := range p.prefixes { + next := strings.ReplaceAll(content, prefix, "[redacted]-") + if next != content { + redacted = true + } + content = next + } + e.Content = content + if redacted { + p.incRedactedMessages() + } + return nil + }) + + r.OnSessionStart("policy-demo-session-start-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionStarts() + return nil + }) + + r.OnSessionEnd("policy-demo-session-end-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionEnds() + return nil + }) + + r.OnAfterToolCall("policy-demo-after-tool-audit", 0, func(_ context.Context, e *hooks.AfterToolCallEvent) error { + p.incAfterToolCall(e.Duration) + return nil + }) + + return nil +} + +func normalizeLower(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +func clampArgNumber(args map[string]any, key string, limit int) { + if args == nil || limit <= 0 { + return + } + v, ok := args[key] + if !ok { + return + } + n, ok := toInt(v) + if !ok { + return + } + if n > limit { + args[key] = limit + } +} + +func toInt(v any) (int, bool) { + maxInt := int(^uint(0) >> 1) + maxIntU64 := uint64(maxInt) + maxInt64 := int64(maxInt) + minInt64 := -maxInt64 - 1 + + switch n := v.(type) { + case int: + return n, true + case int8: + return int(n), true + case int16: + return int(n), true + case int32: + return int(n), true + case int64: + if n < minInt64 || n > maxInt64 { + return 0, false + } + return int(n), true + case uint: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint8: + return int(n), true + case uint16: + return int(n), true + case uint32: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint64: + if n > maxIntU64 { + return 0, false + } + return int(n), true + case float32: + // Truncation is intentional for timeout normalization. + return int(n), true + case float64: + // Truncation is intentional for timeout normalization. + return int(n), true + default: + return 0, false + } +} + +func (p *PolicyDemoPlugin) incBeforeToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BeforeToolCalls++ +} + +func (p *PolicyDemoPlugin) incBlockedToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedToolCalls++ +} + +func (p *PolicyDemoPlugin) incMessageSends() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.MessageSends++ +} + +func (p *PolicyDemoPlugin) incRedactedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.RedactedMessages++ +} + +func (p *PolicyDemoPlugin) incBlockedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedMessages++ +} + +func (p *PolicyDemoPlugin) incSessionStarts() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionStarts++ +} + +func (p *PolicyDemoPlugin) incSessionEnds() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionEnds++ +} + +func (p *PolicyDemoPlugin) incAfterToolCall(d time.Duration) { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.AfterToolCalls++ + p.stats.TotalToolDuration += d +} diff --git a/pkg/plugin/demoplugin/policy_demo_test.go b/pkg/plugin/demoplugin/policy_demo_test.go new file mode 100644 index 0000000000..4d41084f1c --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo_test.go @@ -0,0 +1,189 @@ +package demoplugin + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +func TestPolicyDemoPluginBlocksConfiguredTool(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + BlockedTools: []string{"shell"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "cli"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if !e.Cancel { + t.Fatal("expected tool call to be canceled") + } + if e.CancelReason == "" { + t.Fatal("expected cancel reason") + } + + stats := p.Snapshot() + if stats.BeforeToolCalls != 1 || stats.BlockedToolCalls != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginRedactsOutboundContent(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + RedactPrefixes: []string{"sk-"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + + if e.Cancel { + t.Fatal("did not expect cancellation") + } + if e.Content != "token=[redacted]-abc123" { + t.Fatalf("unexpected redaction result: %q", e.Content) + } + + stats := p.Snapshot() + if stats.MessageSends != 1 || stats.RedactedMessages != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginChannelAllowlist(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + ChannelToolAllowlist: map[string][]string{ + "telegram": {"web_search"}, + }, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + blocked := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), blocked) + if !blocked.Cancel { + t.Fatal("expected tool to be blocked by channel allowlist") + } + + allowed := &hooks.BeforeToolCallEvent{ToolName: "web_search", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), allowed) + if allowed.Cancel { + t.Fatalf("did not expect allowlisted tool to be blocked: %s", allowed.CancelReason) + } +} + +func TestPolicyDemoPluginOutboundGuard(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + DenyOutboundPatterns: []string{"4111-1111-1111-1111", "@corp.internal"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "card=4111-1111-1111-1111"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + if !e.Cancel { + t.Fatal("expected outbound message to be blocked") + } + if e.CancelReason == "" { + t.Fatal("expected block reason") + } + + stats := p.Snapshot() + if stats.BlockedMessages != 1 { + t.Fatalf("expected blocked message count to be 1, got %+v", stats) + } +} + +func TestPolicyDemoPluginNormalizesTimeoutArg(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{MaxToolTimeoutSecond: 30}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ + ToolName: "web_fetch", + Channel: "cli", + Args: map[string]any{ + "timeout": 120, + "timeout_seconds": 90.0, + }, + } + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if got, ok := e.Args["timeout"].(int); !ok || got != 30 { + t.Fatalf("expected timeout to be clamped to 30, got %#v", e.Args["timeout"]) + } + if got, ok := e.Args["timeout_seconds"].(int); !ok || got != 30 { + t.Fatalf("expected timeout_seconds to be clamped to 30, got %#v", e.Args["timeout_seconds"]) + } +} + +func TestPolicyDemoPluginAuditHooks(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + pm.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + pm.HookRegistry().TriggerAfterToolCall( + context.Background(), + &hooks.AfterToolCallEvent{ + ToolName: "web_search", + Duration: 45 * time.Millisecond, + }, + ) + pm.HookRegistry().TriggerSessionEnd(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + + stats := p.Snapshot() + if stats.SessionStarts != 1 || stats.SessionEnds != 1 { + t.Fatalf("unexpected session stats: %+v", stats) + } + if stats.AfterToolCalls != 1 || stats.TotalToolDuration != 45*time.Millisecond { + t.Fatalf("unexpected after_tool_call stats: %+v", stats) + } +} + +func TestPolicyDemoPluginNoConfigNoEffect(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + toolEvent := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), toolEvent) + if toolEvent.Cancel { + t.Fatal("did not expect cancellation with empty config") + } + + msgEvent := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), msgEvent) + if msgEvent.Content != "token=sk-abc123" { + t.Fatalf("did not expect content rewrite, got %q", msgEvent.Content) + } +} + +func TestToIntRejectsInt64OverflowOn32Bit(t *testing.T) { + if strconv.IntSize != 32 { + t.Skip("overflow scenario is specific to 32-bit int") + } + if _, ok := toInt(int64(1 << 40)); ok { + t.Fatal("expected overflow conversion to fail on 32-bit int") + } +} diff --git a/pkg/plugin/manager.go b/pkg/plugin/manager.go new file mode 100644 index 0000000000..c63f67e68a --- /dev/null +++ b/pkg/plugin/manager.go @@ -0,0 +1,275 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package plugin + +import ( + "errors" + "fmt" + "slices" + "sort" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +// APIVersion identifies the compile-time plugin contract version. +const APIVersion = "v1alpha1" + +// SelectionInput controls plugin enable/disable resolution. +type SelectionInput struct { + DefaultEnabled bool + Enabled []string + Disabled []string +} + +// SelectionResult is the normalized output of plugin enable/disable resolution. +type SelectionResult struct { + EnabledNames []string + DisabledNames []string + UnknownEnabled []string + UnknownDisabled []string + Warnings []string +} + +// Plugin is the Phase-1 compile-time contract for PicoClaw extensions. +type Plugin interface { + Name() string + APIVersion() string + Register(registry *hooks.HookRegistry) error +} + +// PluginInfo describes plugin metadata for introspection APIs. +type PluginInfo struct { + Name string `json:"name"` + APIVersion string `json:"api_version"` + Status string `json:"status"` +} + +// PluginDescriptor optionally provides richer plugin metadata. +type PluginDescriptor interface { + Info() PluginInfo +} + +// NormalizePluginName normalizes plugin names for deterministic matching. +func NormalizePluginName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) +} + +// ResolveSelection resolves final enabled/disabled plugin names deterministically. +func ResolveSelection(available []string, in SelectionInput) (SelectionResult, error) { + result := SelectionResult{} + + availableSet := make(map[string]struct{}, len(available)) + for _, name := range available { + normalized := NormalizePluginName(name) + if normalized == "" { + continue + } + availableSet[normalized] = struct{}{} + } + + enabledSet := make(map[string]struct{}, len(in.Enabled)) + for _, name := range in.Enabled { + normalized := NormalizePluginName(name) + if _, exists := enabledSet[normalized]; exists { + result.Warnings = append(result.Warnings, fmt.Sprintf("duplicate enabled plugin %q ignored", normalized)) + continue + } + enabledSet[normalized] = struct{}{} + } + + disabledSet := make(map[string]struct{}, len(in.Disabled)) + for _, name := range in.Disabled { + normalized := NormalizePluginName(name) + if _, exists := disabledSet[normalized]; exists { + result.Warnings = append(result.Warnings, fmt.Sprintf("duplicate disabled plugin %q ignored", normalized)) + continue + } + disabledSet[normalized] = struct{}{} + } + + for name := range enabledSet { + if _, ok := availableSet[name]; !ok { + result.UnknownEnabled = append(result.UnknownEnabled, name) + } + } + sort.Strings(result.UnknownEnabled) + + for name := range disabledSet { + if _, ok := availableSet[name]; !ok { + result.UnknownDisabled = append(result.UnknownDisabled, name) + } + } + sort.Strings(result.UnknownDisabled) + for _, name := range result.UnknownDisabled { + result.Warnings = append(result.Warnings, fmt.Sprintf("unknown disabled plugin %q ignored", name)) + } + + resolvedEnabled := make(map[string]struct{}, len(availableSet)) + if len(enabledSet) > 0 { + for name := range enabledSet { + if _, ok := availableSet[name]; !ok { + continue + } + if _, disabled := disabledSet[name]; disabled { + continue + } + resolvedEnabled[name] = struct{}{} + } + } else if in.DefaultEnabled { + for name := range availableSet { + if _, disabled := disabledSet[name]; disabled { + continue + } + resolvedEnabled[name] = struct{}{} + } + } + + for name := range resolvedEnabled { + result.EnabledNames = append(result.EnabledNames, name) + } + sort.Strings(result.EnabledNames) + + for name := range availableSet { + if _, enabled := resolvedEnabled[name]; enabled { + continue + } + result.DisabledNames = append(result.DisabledNames, name) + } + sort.Strings(result.DisabledNames) + + if len(result.UnknownEnabled) > 0 { + return result, fmt.Errorf("unknown enabled plugins: %s", strings.Join(result.UnknownEnabled, ", ")) + } + return result, nil +} + +// Manager owns a shared hook registry and loaded plugin metadata. +type Manager struct { + mu sync.RWMutex + registry *hooks.HookRegistry + names []string + plugins []Plugin + seen map[string]struct{} +} + +// NewManager creates an empty plugin manager with a fresh hook registry. +func NewManager() *Manager { + return &Manager{ + registry: hooks.NewHookRegistry(), + seen: make(map[string]struct{}), + } +} + +// HookRegistry returns the shared registry where plugins register hooks. +func (m *Manager) HookRegistry() *hooks.HookRegistry { + return m.registry +} + +// Names returns loaded plugin names in registration order. +func (m *Manager) Names() []string { + m.mu.RLock() + defer m.mu.RUnlock() + return slices.Clone(m.names) +} + +// DescribeAll returns plugin metadata in registration order. +func (m *Manager) DescribeAll() []PluginInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + infos := make([]PluginInfo, 0, len(m.plugins)) + for i, p := range m.plugins { + fallbackName := "" + if i < len(m.names) { + fallbackName = m.names[i] + } + infos = append(infos, normalizePluginInfo(p, fallbackName)) + } + return infos +} + +// DescribeEnabled returns metadata for currently enabled plugins. +func (m *Manager) DescribeEnabled() []PluginInfo { + return m.DescribeAll() +} + +// Register loads one plugin into the shared hook registry. +func (m *Manager) Register(p Plugin) error { + if p == nil { + return errors.New("plugin is nil") + } + name := strings.TrimSpace(p.Name()) + if name == "" { + return errors.New("plugin name is required") + } + if got := strings.TrimSpace(p.APIVersion()); got != APIVersion { + if got == "" { + got = "" + } + return fmt.Errorf( + "plugin %q api version mismatch: got %s, want %s", + name, + got, + APIVersion, + ) + } + + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.seen[name]; exists { + return fmt.Errorf("plugin %q already registered", name) + } + if err := p.Register(m.registry); err != nil { + return fmt.Errorf("register plugin %q: %w", name, err) + } + m.seen[name] = struct{}{} + m.names = append(m.names, name) + m.plugins = append(m.plugins, p) + return nil +} + +// RegisterAll loads plugins sequentially. +func (m *Manager) RegisterAll(plugins ...Plugin) error { + for _, p := range plugins { + if err := m.Register(p); err != nil { + return err + } + } + return nil +} + +func normalizePluginInfo(p Plugin, fallbackName string) PluginInfo { + info := PluginInfo{ + Name: strings.TrimSpace(fallbackName), + APIVersion: strings.TrimSpace(p.APIVersion()), + Status: "enabled", + } + if descriptor, ok := p.(PluginDescriptor); ok { + described := descriptor.Info() + if name := strings.TrimSpace(described.Name); name != "" { + info.Name = name + } + if version := strings.TrimSpace(described.APIVersion); version != "" { + info.APIVersion = version + } + if status := strings.TrimSpace(described.Status); status != "" { + info.Status = status + } + } + if info.Name == "" { + info.Name = strings.TrimSpace(p.Name()) + } + if info.APIVersion == "" { + info.APIVersion = APIVersion + } + if info.Status == "" { + info.Status = "enabled" + } + return info +} diff --git a/pkg/plugin/manager_test.go b/pkg/plugin/manager_test.go new file mode 100644 index 0000000000..c7423fd96d --- /dev/null +++ b/pkg/plugin/manager_test.go @@ -0,0 +1,374 @@ +package plugin + +import ( + "context" + "errors" + "slices" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +type testPlugin struct { + name string + apiVersion string + registerFn func(*hooks.HookRegistry) error +} + +func (p testPlugin) Name() string { + return p.name +} + +func (p testPlugin) Register(r *hooks.HookRegistry) error { + if p.registerFn != nil { + return p.registerFn(r) + } + return nil +} + +func (p testPlugin) APIVersion() string { + if p.apiVersion == "" { + return APIVersion + } + return p.apiVersion +} + +type descriptorTestPlugin struct { + testPlugin + info PluginInfo +} + +func (p descriptorTestPlugin) Info() PluginInfo { + return p.info +} + +func TestNewManager(t *testing.T) { + m := NewManager() + if m == nil { + t.Fatal("expected manager") + } + if m.HookRegistry() == nil { + t.Fatal("expected non-nil hook registry") + } + if len(m.Names()) != 0 { + t.Fatalf("expected empty names, got %v", m.Names()) + } +} + +func TestRegisterPluginAndTriggerHook(t *testing.T) { + m := NewManager() + called := false + p := testPlugin{ + name: "audit", + registerFn: func(r *hooks.HookRegistry) error { + r.OnSessionStart("audit-session", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + called = true + return nil + }) + return nil + }, + } + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + if got := m.Names(); len(got) != 1 || got[0] != "audit" { + t.Fatalf("unexpected names: %v", got) + } + + m.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{ + AgentID: "a1", + SessionKey: "s1", + }) + if !called { + t.Fatal("expected plugin hook to be called") + } +} + +func TestRegisterRejectsNilPlugin(t *testing.T) { + m := NewManager() + if err := m.Register(nil); err == nil { + t.Fatal("expected error for nil plugin") + } +} + +func TestRegisterRejectsEmptyName(t *testing.T) { + m := NewManager() + if err := m.Register(testPlugin{}); err == nil { + t.Fatal("expected error for empty name") + } +} + +func TestRegisterRejectsDuplicateName(t *testing.T) { + m := NewManager() + p := testPlugin{name: "dup"} + if err := m.Register(p); err != nil { + t.Fatalf("unexpected first register error: %v", err) + } + if err := m.Register(p); err == nil { + t.Fatal("expected duplicate name error") + } +} + +func TestRegisterPropagatesPluginError(t *testing.T) { + m := NewManager() + want := errors.New("register failed") + p := testPlugin{ + name: "bad", + registerFn: func(_ *hooks.HookRegistry) error { + return want + }, + } + err := m.Register(p) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, want) { + t.Fatalf("expected wrapped error %v, got %v", want, err) + } +} + +func TestRegisterRejectsPluginVersionMismatch(t *testing.T) { + m := NewManager() + p := testPlugin{ + name: "old-plugin", + apiVersion: "v0", + } + err := m.Register(p) + if err == nil { + t.Fatal("expected version mismatch error") + } +} + +func TestDescribeAll_UsesDescriptorWhenImplemented(t *testing.T) { + m := NewManager() + p := descriptorTestPlugin{ + testPlugin: testPlugin{name: "descriptor"}, + info: PluginInfo{ + Name: " descriptor-visible ", + APIVersion: " custom-v1 ", + Status: " active ", + }, + } + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + + got := m.DescribeAll() + want := []PluginInfo{ + { + Name: "descriptor-visible", + APIVersion: "custom-v1", + Status: "active", + }, + } + if !slices.Equal(got, want) { + t.Fatalf("DescribeAll() mismatch: got %v, want %v", got, want) + } +} + +func TestDescribeAll_FallsBackForPlainPlugin(t *testing.T) { + m := NewManager() + p := testPlugin{name: "plain"} + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + + got := m.DescribeAll() + want := []PluginInfo{ + { + Name: "plain", + APIVersion: APIVersion, + Status: "enabled", + }, + } + if !slices.Equal(got, want) { + t.Fatalf("DescribeAll() mismatch: got %v, want %v", got, want) + } +} + +func TestDescribeEnabled_MatchesDescribeAllForNow(t *testing.T) { + m := NewManager() + plain := testPlugin{name: "plain"} + described := descriptorTestPlugin{ + testPlugin: testPlugin{name: "described"}, + info: PluginInfo{ + Name: " described-visible ", + }, + } + + if err := m.RegisterAll(plain, described); err != nil { + t.Fatalf("RegisterAll() error = %v", err) + } + + all := m.DescribeAll() + enabled := m.DescribeEnabled() + if !slices.Equal(enabled, all) { + t.Fatalf("DescribeEnabled() mismatch: got %v, want %v", enabled, all) + } + + wantAll := []PluginInfo{ + { + Name: "plain", + APIVersion: APIVersion, + Status: "enabled", + }, + { + Name: "described-visible", + APIVersion: APIVersion, + Status: "enabled", + }, + } + if !slices.Equal(all, wantAll) { + t.Fatalf("DescribeAll() order/content mismatch: got %v, want %v", all, wantAll) + } +} + +func TestResolveSelection_DefaultEnabled(t *testing.T) { + result, err := ResolveSelection( + []string{"beta", "alpha", "gamma"}, + SelectionInput{ + DefaultEnabled: true, + Disabled: []string{"beta"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"alpha", "gamma"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"beta"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func TestResolveSelection_EnabledListOnly(t *testing.T) { + result, err := ResolveSelection( + []string{"a", "b", "c"}, + SelectionInput{ + DefaultEnabled: true, + Enabled: []string{"c", "a"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a", "c"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"b"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func TestResolveSelection_DisabledWinsOverlap(t *testing.T) { + result, err := ResolveSelection( + []string{"a", "b", "c"}, + SelectionInput{ + Enabled: []string{"a", "b"}, + Disabled: []string{"b"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"b", "c"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func TestResolveSelection_UnknownEnabledFails(t *testing.T) { + result, err := ResolveSelection( + []string{"a"}, + SelectionInput{ + Enabled: []string{"missing"}, + }, + ) + if err == nil { + t.Fatal("expected error for unknown enabled plugin") + } + if !strings.Contains(err.Error(), "missing") { + t.Fatalf("expected error to mention unknown plugin, got %v", err) + } + if !slices.Equal(result.UnknownEnabled, []string{"missing"}) { + t.Fatalf("UnknownEnabled mismatch: got %v", result.UnknownEnabled) + } +} + +func TestResolveSelection_UnknownDisabledWarns(t *testing.T) { + result, err := ResolveSelection( + []string{"a"}, + SelectionInput{ + DefaultEnabled: true, + Disabled: []string{"missing"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if len(result.DisabledNames) != 0 { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } + if !slices.Equal(result.UnknownDisabled, []string{"missing"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", result.UnknownDisabled) + } + if !hasWarningSubstring(result.Warnings, `unknown disabled plugin "missing" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", result.Warnings) + } +} + +func TestResolveSelection_NormalizationAndDedupe(t *testing.T) { + result, err := ResolveSelection( + []string{" Alpha ", "beta", "gamma"}, + SelectionInput{ + Enabled: []string{"ALPHA", " alpha ", "BETA", "beta"}, + Disabled: []string{" beta", "BETA", "missing", " MISSING "}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"alpha"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"beta", "gamma"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } + if !slices.Equal(result.UnknownDisabled, []string{"missing"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", result.UnknownDisabled) + } + if !hasWarningSubstring(result.Warnings, `duplicate enabled plugin "alpha" ignored`) { + t.Fatalf("expected duplicate enabled warning for alpha, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate enabled plugin "beta" ignored`) { + t.Fatalf("expected duplicate enabled warning for beta, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate disabled plugin "beta" ignored`) { + t.Fatalf("expected duplicate disabled warning for beta, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate disabled plugin "missing" ignored`) { + t.Fatalf("expected duplicate disabled warning for missing, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `unknown disabled plugin "missing" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", result.Warnings) + } +} + +func hasWarningSubstring(warnings []string, sub string) bool { + for _, warning := range warnings { + if strings.Contains(warning, sub) { + return true + } + } + return false +}