diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go index f754abc652..f89c23795b 100644 --- a/cmd/picoclaw/internal/agent/helpers.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -12,6 +12,7 @@ import ( "github.com/chzyer/readline" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/logger" @@ -51,6 +52,24 @@ func agentCmd(message, sessionKey, model string, debug bool) error { defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + pluginsToEnable, pluginSummary, err := pluginruntime.ResolveConfiguredPlugins(cfg) + if err != nil { + return fmt.Errorf("error resolving configured plugins: %w", err) + } + if len(pluginsToEnable) > 0 { + if err := agentLoop.EnablePlugins(pluginsToEnable...); err != nil { + return fmt.Errorf("error enabling plugins: %w", err) + } + } + logger.InfoCF("agent", "Plugin selection resolved", + map[string]any{ + "plugins_enabled": pluginSummary.Enabled, + "plugins_disabled": pluginSummary.Disabled, + "plugins_unknown_enabled": pluginSummary.UnknownEnabled, + "plugins_unknown_disabled": pluginSummary.UnknownDisabled, + "plugins_warnings": pluginSummary.Warnings, + }) + // Print agent startup info (only for interactive mode) startupInfo := agentLoop.GetStartupInfo() logger.InfoCF("agent", "Agent initialized", diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 747f7d44e9..e3467cb267 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -10,6 +10,7 @@ import ( "time" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -61,6 
+62,23 @@ func gatewayCmd(debug bool) error { msgBus := bus.NewMessageBus() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + pluginsToEnable, pluginSummary, err := pluginruntime.ResolveConfiguredPlugins(cfg) + if err != nil { + return fmt.Errorf("error resolving configured plugins: %w", err) + } + if len(pluginsToEnable) > 0 { + if enableErr := agentLoop.EnablePlugins(pluginsToEnable...); enableErr != nil { + return fmt.Errorf("error enabling plugins: %w", enableErr) + } + } + logger.InfoCF("agent", "Plugin selection resolved", + map[string]any{ + "plugins_enabled": pluginSummary.Enabled, + "plugins_disabled": pluginSummary.Disabled, + "plugins_unknown_enabled": pluginSummary.UnknownEnabled, + "plugins_unknown_disabled": pluginSummary.UnknownDisabled, + "plugins_warnings": pluginSummary.Warnings, + }) // Print agent startup info fmt.Println("\nšŸ“¦ Agent Status:") diff --git a/cmd/picoclaw/internal/plugin/command.go b/cmd/picoclaw/internal/plugin/command.go new file mode 100644 index 0000000000..b156aa1018 --- /dev/null +++ b/cmd/picoclaw/internal/plugin/command.go @@ -0,0 +1,19 @@ +package plugin + +import "github.com/spf13/cobra" + +func NewPluginCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "plugin", + Short: "Inspect and validate plugins", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return cmd.Help() + }, + } + + cmd.AddCommand(newListCommand()) + cmd.AddCommand(newLintSubcommand()) + + return cmd +} diff --git a/cmd/picoclaw/internal/plugin/command_test.go b/cmd/picoclaw/internal/plugin/command_test.go new file mode 100644 index 0000000000..b457ad8471 --- /dev/null +++ b/cmd/picoclaw/internal/plugin/command_test.go @@ -0,0 +1,44 @@ +package plugin + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPluginCommand(t *testing.T) { + cmd := NewPluginCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "plugin", cmd.Use) + 
assert.Equal(t, "Inspect and validate plugins", cmd.Short) + + assert.True(t, cmd.HasSubCommands()) + assert.True(t, cmd.HasAvailableSubCommands()) + + assert.False(t, cmd.HasFlags()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) + + allowedCommands := []string{ + "list", + "lint", + } + + subcommands := cmd.Commands() + assert.Len(t, subcommands, len(allowedCommands)) + + for _, subcmd := range subcommands { + found := slices.Contains(allowedCommands, subcmd.Name()) + assert.True(t, found, "unexpected subcommand %q", subcmd.Name()) + + assert.False(t, subcmd.Hidden) + } +} diff --git a/cmd/picoclaw/internal/plugin/lint.go b/cmd/picoclaw/internal/plugin/lint.go new file mode 100644 index 0000000000..c31b87d47c --- /dev/null +++ b/cmd/picoclaw/internal/plugin/lint.go @@ -0,0 +1,41 @@ +package plugin + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" + "github.com/sipeed/picoclaw/pkg/config" +) + +func newLintSubcommand() *cobra.Command { + configPath := internal.GetConfigPath() + + cmd := &cobra.Command{ + Use: "lint", + Short: "Lint plugin configuration", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg, err := config.LoadConfig(configPath) + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + if _, _, err := pluginruntime.ResolveConfiguredPlugins(cfg); err != nil { + return fmt.Errorf("invalid plugin config: %w", err) + } + + if _, err := fmt.Fprintln(cmd.OutOrStdout(), "plugin config lint: ok"); err != nil { + return err + } + + return nil + }, + } + + cmd.Flags().StringVar(&configPath, "config", internal.GetConfigPath(), "Path to config file") + + return cmd +} diff --git a/cmd/picoclaw/internal/plugin/lint_test.go b/cmd/picoclaw/internal/plugin/lint_test.go new file mode 100644 index 
0000000000..0e007983e8 --- /dev/null +++ b/cmd/picoclaw/internal/plugin/lint_test.go @@ -0,0 +1,73 @@ +package plugin + +import ( + "bytes" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewLintSubcommand(t *testing.T) { + cmd := newLintSubcommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "lint", cmd.Use) + assert.Equal(t, "Lint plugin configuration", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.False(t, cmd.HasSubCommands()) + assert.True(t, cmd.HasFlags()) + + configFlag := cmd.Flags().Lookup("config") + require.NotNil(t, configFlag) + assert.Equal(t, internal.GetConfigPath(), configFlag.DefValue) +} + +func TestPluginLint_UnknownEnabledExitNonZero(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: []string{"missing-plugin"}, + Disabled: []string{}, + } + require.NoError(t, config.SaveConfig(configPath, cfg)) + + cmd := NewPluginCommand() + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs([]string{"lint", "--config", configPath}) + + err := cmd.Execute() + require.Error(t, err) + assert.ErrorContains(t, err, "missing-plugin") +} + +func TestPluginLint_ValidConfigExitZero(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: []string{}, + Disabled: []string{}, + } + require.NoError(t, config.SaveConfig(configPath, cfg)) + + out := &bytes.Buffer{} + cmd := NewPluginCommand() + cmd.SetOut(out) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs([]string{"lint", "--config", configPath}) + + err := cmd.Execute() + require.NoError(t, err) + assert.Contains(t, out.String(), "plugin config 
lint: ok") +} diff --git a/cmd/picoclaw/internal/plugin/list.go b/cmd/picoclaw/internal/plugin/list.go new file mode 100644 index 0000000000..b1d0e5e0f5 --- /dev/null +++ b/cmd/picoclaw/internal/plugin/list.go @@ -0,0 +1,110 @@ +package plugin + +import ( + "encoding/json" + "fmt" + "io" + "sort" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" +) + +const ( + formatText = "text" + formatJSON = "json" +) + +type pluginStatus struct { + Name string `json:"name"` + Status string `json:"status"` +} + +func newListCommand() *cobra.Command { + var format string + + cmd := &cobra.Command{ + Use: "list", + Short: "List configured plugin status", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + if format != formatText && format != formatJSON { + return fmt.Errorf("invalid value for --format: %q (allowed: %s, %s)", format, formatText, formatJSON) + } + + cfg, err := internal.LoadConfig() + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + _, summary, err := pluginruntime.ResolveConfiguredPlugins(cfg) + statuses := buildPluginStatuses(summary) + + if outputErr := renderPluginStatuses(cmd.OutOrStdout(), format, statuses); outputErr != nil { + return outputErr + } + if err != nil { + return fmt.Errorf("error resolving configured plugins: %w", err) + } + + return nil + }, + } + + cmd.Flags().StringVar(&format, "format", formatText, "Output format (text|json)") + + return cmd +} + +func buildPluginStatuses(summary pluginruntime.Summary) []pluginStatus { + total := len(summary.Enabled) + + len(summary.Disabled) + + len(summary.UnknownEnabled) + + len(summary.UnknownDisabled) + statuses := make([]pluginStatus, 0, total) + + for _, name := range summary.Enabled { + statuses = append(statuses, pluginStatus{Name: name, Status: "enabled"}) + } + for _, name := range summary.Disabled { + statuses = append(statuses, 
pluginStatus{Name: name, Status: "disabled"}) + } + for _, name := range summary.UnknownEnabled { + statuses = append(statuses, pluginStatus{Name: name, Status: "unknown-enabled"}) + } + for _, name := range summary.UnknownDisabled { + statuses = append(statuses, pluginStatus{Name: name, Status: "unknown-disabled"}) + } + + sort.Slice(statuses, func(i, j int) bool { + if statuses[i].Name == statuses[j].Name { + return statuses[i].Status < statuses[j].Status + } + return statuses[i].Name < statuses[j].Name + }) + + return statuses +} + +func renderPluginStatuses(w io.Writer, format string, statuses []pluginStatus) error { + switch format { + case formatText: + if _, err := fmt.Fprintln(w, "NAME\tSTATUS"); err != nil { + return err + } + for _, status := range statuses { + if _, err := fmt.Fprintf(w, "%s\t%s\n", status.Name, status.Status); err != nil { + return err + } + } + return nil + case formatJSON: + encoder := json.NewEncoder(w) + encoder.SetIndent("", " ") + return encoder.Encode(statuses) + default: + return fmt.Errorf("invalid value for --format: %q (allowed: %s, %s)", format, formatText, formatJSON) + } +} diff --git a/cmd/picoclaw/internal/plugin/list_test.go b/cmd/picoclaw/internal/plugin/list_test.go new file mode 100644 index 0000000000..d541d81cd2 --- /dev/null +++ b/cmd/picoclaw/internal/plugin/list_test.go @@ -0,0 +1,61 @@ +package plugin + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/pluginruntime" +) + +func TestNewListSubcommand(t *testing.T) { + cmd := newListCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "list", cmd.Use) + assert.Equal(t, "List configured plugin status", cmd.Short) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.False(t, cmd.HasSubCommands()) + assert.True(t, cmd.HasFlags()) + + assert.Len(t, cmd.Aliases, 0) + + formatFlag := cmd.Flags().Lookup("format") + require.NotNil(t, 
formatFlag) + assert.Equal(t, formatText, formatFlag.DefValue) +} + +func TestNewListSubcommand_RejectsUnknownFormat(t *testing.T) { + cmd := newListCommand() + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs([]string{"--format", "yaml"}) + + err := cmd.Execute() + require.Error(t, err) + assert.ErrorContains(t, err, `invalid value for --format: "yaml"`) +} + +func TestBuildPluginStatuses_DeterministicOrder(t *testing.T) { + summary := pluginruntime.Summary{ + Enabled: []string{"beta"}, + Disabled: []string{"alpha"}, + UnknownEnabled: []string{"zeta"}, + UnknownDisabled: []string{"eta"}, + } + + got := buildPluginStatuses(summary) + + assert.Equal(t, []pluginStatus{ + {Name: "alpha", Status: "disabled"}, + {Name: "beta", Status: "enabled"}, + {Name: "eta", Status: "unknown-disabled"}, + {Name: "zeta", Status: "unknown-enabled"}, + }, got) +} diff --git a/cmd/picoclaw/internal/pluginruntime/bootstrap.go b/cmd/picoclaw/internal/pluginruntime/bootstrap.go new file mode 100644 index 0000000000..4cfcffde2c --- /dev/null +++ b/cmd/picoclaw/internal/pluginruntime/bootstrap.go @@ -0,0 +1,64 @@ +package pluginruntime + +import ( + "fmt" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/builtin" +) + +type Summary struct { + Enabled []string + Disabled []string + UnknownEnabled []string + UnknownDisabled []string + Warnings []string +} + +func ResolveConfiguredPlugins(cfg *config.Config) ([]plugin.Plugin, Summary, error) { + if cfg == nil { + return nil, Summary{}, fmt.Errorf("config is nil") + } + + resolved, err := plugin.ResolveSelection( + builtin.Names(), + plugin.SelectionInput{ + DefaultEnabled: cfg.Plugins.DefaultEnabled, + Enabled: cfg.Plugins.Enabled, + Disabled: cfg.Plugins.Disabled, + }, + ) + + summary := Summary{ + Enabled: resolved.EnabledNames, + Disabled: resolved.DisabledNames, + UnknownEnabled: resolved.UnknownEnabled, + UnknownDisabled: 
resolved.UnknownDisabled, + Warnings: resolved.Warnings, + } + if err != nil { + return nil, summary, err + } + + catalog := builtin.Catalog() + normalizedCatalog := make(map[string]builtin.Factory, len(catalog)) + for name, factory := range catalog { + normalizedCatalog[plugin.NormalizePluginName(name)] = factory + } + + instances := make([]plugin.Plugin, 0, len(resolved.EnabledNames)) + for _, name := range resolved.EnabledNames { + factory, ok := normalizedCatalog[name] + if !ok { + return nil, summary, fmt.Errorf("builtin plugin %q has no factory", name) + } + instance := factory() + if instance == nil { + return nil, summary, fmt.Errorf("builtin plugin %q factory returned nil", name) + } + instances = append(instances, instance) + } + + return instances, summary, nil +} diff --git a/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go b/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go new file mode 100644 index 0000000000..93b2be2b37 --- /dev/null +++ b/cmd/picoclaw/internal/pluginruntime/bootstrap_test.go @@ -0,0 +1,106 @@ +package pluginruntime + +import ( + "slices" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/builtin" +) + +func TestResolveConfiguredPlugins_UnknownEnabledReturnsError(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: []string{"missing-plugin"}, + Disabled: []string{}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err == nil { + t.Fatal("expected error for unknown enabled plugin") + } + if !strings.Contains(err.Error(), "missing-plugin") { + t.Fatalf("expected error to mention missing plugin, got %v", err) + } + if len(instances) != 0 { + t.Fatalf("expected no instances on error, got %d", len(instances)) + } + if !slices.Equal(summary.UnknownEnabled, []string{"missing-plugin"}) { + t.Fatalf("UnknownEnabled mismatch: got %v", 
summary.UnknownEnabled) + } +} + +func TestResolveConfiguredPlugins_ReturnsDeterministicInstances(t *testing.T) { + available := builtin.Names() + if len(available) == 0 { + t.Fatal("expected at least one builtin plugin") + } + + enabled := slices.Clone(available) + slices.Reverse(enabled) + + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: false, + Enabled: enabled, + Disabled: []string{}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err != nil { + t.Fatalf("ResolveConfiguredPlugins() error = %v", err) + } + + gotNames := pluginNames(instances) + if !slices.Equal(gotNames, available) { + t.Fatalf("plugin names mismatch: got %v, want %v", gotNames, available) + } + if !slices.Equal(summary.Enabled, available) { + t.Fatalf("summary enabled mismatch: got %v, want %v", summary.Enabled, available) + } +} + +func TestResolveConfiguredPlugins_UnknownDisabledWarns(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Plugins = config.PluginsConfig{ + DefaultEnabled: true, + Enabled: []string{}, + Disabled: []string{"missing-plugin"}, + } + + instances, summary, err := ResolveConfiguredPlugins(cfg) + if err != nil { + t.Fatalf("ResolveConfiguredPlugins() error = %v", err) + } + + expectedEnabled := builtin.Names() + if !slices.Equal(pluginNames(instances), expectedEnabled) { + t.Fatalf("plugin names mismatch: got %v, want %v", pluginNames(instances), expectedEnabled) + } + if !slices.Equal(summary.UnknownDisabled, []string{"missing-plugin"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", summary.UnknownDisabled) + } + if !hasWarningSubstring(summary.Warnings, `unknown disabled plugin "missing-plugin" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", summary.Warnings) + } +} + +func pluginNames(instances []plugin.Plugin) []string { + names := make([]string, 0, len(instances)) + for _, instance := range instances { + names = append(names, instance.Name()) + } + return names +} + +func 
hasWarningSubstring(warnings []string, sub string) bool { + for _, warning := range warnings { + if strings.Contains(warning, sub) { + return true + } + } + return false +} diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 6db69c9902..4db5603094 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -19,6 +19,7 @@ import ( "github.com/sipeed/picoclaw/cmd/picoclaw/internal/gateway" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/migrate" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/onboard" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/plugin" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/skills" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/status" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/version" @@ -41,6 +42,7 @@ func NewPicoclawCommand() *cobra.Command { status.NewStatusCommand(), cron.NewCronCommand(), migrate.NewMigrateCommand(), + plugin.NewPluginCommand(), skills.NewSkillsCommand(), version.NewVersionCommand(), ) diff --git a/cmd/picoclaw/main_test.go b/cmd/picoclaw/main_test.go index 3740ba358e..2d9a84efe6 100644 --- a/cmd/picoclaw/main_test.go +++ b/cmd/picoclaw/main_test.go @@ -39,6 +39,7 @@ func TestNewPicoclawCommand(t *testing.T) { "gateway", "migrate", "onboard", + "plugin", "skills", "status", "version", diff --git a/docs/plans/2026-02-28-plugin-system-phase2-phase3-design.md b/docs/plans/2026-02-28-plugin-system-phase2-phase3-design.md new file mode 100644 index 0000000000..79b2b98a35 --- /dev/null +++ b/docs/plans/2026-02-28-plugin-system-phase2-phase3-design.md @@ -0,0 +1,240 @@ +# PR #473 Phase 2/3 Rethink + +## What Changed + +The previous draft mixed control-plane work with schema-heavy plugin config too early. +This rethink narrows scope so implementation matches the current codebase: + +- Phase 2 is only plugin selection and runtime wiring. +- Phase 3 is introspection, linting, and optional plugin-specific settings. +- `plugin.Plugin` stays unchanged in both phases. 
+ +## Hard Scope + +- In-process built-in plugins only. +- No runtime loader, no dynamic module loading, no hot reload. +- Runtime loading remains Phase 4 and must reuse Phase 2/3 abstractions. + +## Baseline + +- Baseline is PR #473 phase-0/phase-1 behavior (`pkg/plugin`, `pkg/hooks`, `pkg/agent/loop.go`). +- Existing deployments without `plugins` config must keep the same effective behavior. + +## Implemented Snapshot + +Implemented in current Phase 2/3 scope: + +- Phase 2: +- typed plugin config schema with `plugins.default_enabled`, `plugins.enabled`, `plugins.disabled`. +- deterministic plugin resolver. +- startup wiring in both `agent` and `gateway` paths. +- Phase 3: +- plugin metadata introspection in manager APIs for built-ins. +- CLI support for listing and linting plugin config (`list` output is `name` + `status`). + +Command examples: + +```bash +picoclaw plugin list +picoclaw plugin list --format json +picoclaw plugin lint --config ~/.picoclaw/config.json +``` + +Precision note: + +- No dynamic runtime plugin loading/hot reload. +- Plugin-specific `settings` remain optional future work and are not required by the implemented Phase 2 selection plane. + +## Phase 2: Selection Plane (Minimal, Deterministic) + +### Goal + +Make plugin enable/disable operational from config with deterministic behavior and fail-fast handling. + +### Config (Phase 2 only) + +Project-facing examples should follow the repo default config format (JSON). + +```json +{ + "plugins": { + "default_enabled": true, + "enabled": ["policy-demo"], + "disabled": ["legacy_policy"] + } +} +``` + +No plugin-specific `settings` in Phase 2. + +### Resolution Rules (authoritative) + +For each built-in plugin name in sorted order: + +1. normalize names (`trim`, `lowercase`) for matching. +2. if in `disabled`, mark disabled. +3. else if `enabled` list is non-empty, enable only if listed. +4. else enable only if `default_enabled=true`. 
+ +### Error Policy (Phase 2) + +- unknown name in `enabled`: startup error. +- unknown name in `disabled`: warning only. +- duplicates after normalization: dedupe and warn. +- overlap between `enabled` and `disabled`: disabled wins. + +### Required Code Changes + +- `pkg/config/config.go`, `pkg/config/defaults.go` + - add typed `plugins` block (`default_enabled`, `enabled`, `disabled`). +- `pkg/plugin/manager.go` + - add built-in registry map and deterministic resolver. + - expose resolution result buckets (`enabled`, `disabled`, `unknown`). +- `cmd/picoclaw/internal/agent/helpers.go` + - resolve plugins from config and wire into `loop.EnablePlugins(...)`. +- `cmd/picoclaw/internal/gateway/helpers.go` + - same as agent path. +- `pkg/agent/loop.go` + - keep existing plugin interface and lifecycle. + - add startup diagnostics with final resolved plugin names. + +### PR Slicing (Review-Friendly) + +Keep Phase 2 and Phase 3 as separate PR series. + +1. Phase 2 PR-A: config schema + resolver only. +2. Phase 2 PR-B: agent/gateway startup wiring + startup diagnostics. +3. Phase 2 PR-C: tests + docs updates. +4. Phase 3 PR-A: manager introspection + metadata side interface. +5. Phase 3 PR-B: `plugin list` and `plugin lint` commands. +6. Phase 3 PR-C: observability polishing + integration tests + docs. + +### Phase 2 Acceptance Gate + +- No `plugins` block: behavior matches baseline. +- Unknown name in `enabled` fails startup with actionable message. +- Resolution order and result are deterministic. +- Entry points actually wire resolved plugins (no silent no-op). +- Startup logs show enabled/disabled plugin sets. + +### Phase 2 Test Matrix + +1. No `plugins` block. +- Expected: same enabled plugin set as baseline. +- Expected startup summary keys: `plugins_enabled`, `plugins_disabled`, `plugins_unknown_enabled`, `plugins_unknown_disabled`, `plugins_warnings`. +2. `enabled=["policy-demo"]`, empty `disabled`. +- Expected: only `policy-demo` loaded. 
+- Expected startup summary key examples: `plugins_enabled=["policy-demo"]`, `plugins_disabled=[]`. +3. `disabled=["policy-demo"]`, empty `enabled`. +- Expected: `policy-demo` not loaded. +- Expected startup summary key examples: `plugins_enabled=[]`, `plugins_disabled=["policy-demo"]`. +4. Overlap: `enabled=["policy-demo"]`, `disabled=["policy-demo"]`. +- Expected: plugin disabled (disabled wins). +- Expected startup summary key examples: `plugins_enabled=[]`, `plugins_disabled=["policy-demo"]`. +5. Unknown in `enabled`: `enabled=["not_exists"]`. +- Expected: startup fails. +- Expected error text includes: `unknown plugin in enabled`. +6. Unknown in `disabled`: `disabled=["not_exists"]`. +- Expected: startup continues with warning. +- Expected warning text includes: `unknown plugin in disabled`. +7. Duplicates/case variants: `enabled=["Policy_Demo","policy_demo"]`. +- Expected: deduped after normalization and warning emitted. +- Expected warning text includes: `duplicate plugin name after normalization`. + +## Phase 3: Introspection Plane (DX + Validation) + +### Goal + +Make plugin state inspectable and config validation review-friendly. + +### Capabilities + +- Manager introspection: + - `DescribeAll() []PluginInfo` + - `DescribeEnabled() []PluginInfo` +- Optional metadata side interface for plugins: + +```go +type PluginDescriptor interface { + Info() PluginInfo +} +``` + +- CLI: + - `picoclaw plugin list` (text/json). + - `picoclaw plugin lint --config /path/to/config.json`. +- Diagnostics (implemented now): + - startup summary fields emitted by entrypoints: + - `plugins_enabled` + - `plugins_disabled` + - `plugins_unknown_enabled` + - `plugins_unknown_disabled` + - `plugins_warnings` +- Diagnostics (deferred): + - per-hook invocation outcome fields (`plugin`, `hook`, `result`, `duration_ms`). + +### Optional Phase 3 Extension + +If needed after list/lint lands, introduce plugin-specific `settings` with strict schema validation. +This is explicitly Phase 3, not Phase 2. 
+ +### Phase 3 Acceptance Gate + +- `plugin list` output is stable in text and JSON with fields: `name`, `status`. +- `plugin lint` returns non-zero on invalid plugin names/config. +- Startup diagnostics include plugin resolution summary fields listed above. +- Tests cover one disabled path and one lint failure path (per current command contracts). + +### Phase 3 Test Matrix + +1. `picoclaw plugin list` text output. +- Expected: deterministic ordering and fields: `name`, `status`. +2. `picoclaw plugin list --format json`. +- Expected: stable JSON schema and deterministic ordering (`name`, `status` only). +3. `picoclaw plugin lint --config /path/to/config.json` with a valid config. +- Expected: exit code `0`. +4. `picoclaw plugin lint --config /path/to/config.json` with an invalid plugin name. +- Expected: non-zero exit. +- Expected error text includes: `unknown plugin`. + +## Rollback Runbook + +1. Revert to baseline behavior. +- Action: remove `plugins` block from config and restart process. +- Success signal: startup summaries return to expected enabled/disabled sets. +2. Recover from bad selection config. +- Action: clear `plugins.enabled` and `plugins.disabled`, restart. +- Success signal: startup summaries show valid resolution and no startup error. +3. Recover from Phase 3 command regressions. +- Action: disable plugin command surface with feature flag and restart. +- Success signal: `plugin` command group hidden/disabled in CLI help. +4. Incident confirmation checks. +- Verify startup logs include: +- `plugins_enabled` +- `plugins_disabled` +- `plugins_unknown_enabled` +- `plugins_unknown_disabled` +- `plugins_warnings` + +## Why This Is More Sound + +- Matches current interfaces (`EnablePlugins` with concrete plugin instances). +- Avoids premature schema coupling before metadata/lint tooling exists. +- Eliminates silent rollout risk by making entrypoint wiring a Phase 2 gate. +- Keeps a clean migration path to Phase 4 runtime sources. 
+ +## External Alignment (Informational) + +This phase split follows a common Go OSS progression: +- compile-time plugin selection first +- discovery/validation CLI second +- runtime loader and trust model last + +These references are context only (not proof of direct feature parity with PicoClaw implementation): + +Reference patterns: +- Go `plugin` package caveats: `https://pkg.go.dev/plugin` +- HashiCorp `go-plugin` runtime model: `https://pkg.go.dev/github.com/hashicorp/go-plugin` +- module listing/validation command patterns (Caddy/Terraform style): + - `https://caddyserver.com/docs/command-line` + - `https://developer.hashicorp.com/terraform/cli/commands/providers/schema` diff --git a/docs/plans/2026-03-01-plugin-system-phase2-phase3-implementation.md b/docs/plans/2026-03-01-plugin-system-phase2-phase3-implementation.md new file mode 100644 index 0000000000..34436cab42 --- /dev/null +++ b/docs/plans/2026-03-01-plugin-system-phase2-phase3-implementation.md @@ -0,0 +1,576 @@ +# Plugin System Phase 2/3 Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. +> **Parity note (current stack):** This is a historical execution plan. For the shipped contract in PRs #936-#939, treat `docs/plugin-system-roadmap.md` as authoritative for current CLI/logging behavior. + +**Goal:** Implement Phase 2 (config-driven plugin selection + runtime wiring) and Phase 3 (plugin introspection + CLI list/lint + diagnostics) with small, reviewable PRs. + +**Architecture:** Keep in-process compile-time plugins as the only runtime model. Build a deterministic selection plane first (`config` + `plugin` resolver + bootstrap wiring), then add a non-breaking introspection plane (`PluginInfo`, list/lint commands) without changing `plugin.Plugin` contract. Use TDD for each slice and keep changes split into small commits. 
+ +**Tech Stack:** Go, Cobra CLI, Testify, existing PicoClaw `pkg/config`, `pkg/plugin`, `pkg/agent`, `cmd/picoclaw/internal/*`. + +--- + +## Preflight Notes + +- Execute this plan inside the dedicated worktree created during design/brainstorming. +- Use `@test-driven-development` for every task (`red -> green -> refactor`). +- Before final handoff, use `@verification-before-completion`. +- Before opening PRs, use `@requesting-code-review`. + +### Task 1: Add Plugin Config Schema and Defaults (Phase 2) + +**Files:** +- Modify: `pkg/config/config.go` +- Modify: `pkg/config/defaults.go` +- Test: `pkg/config/config_test.go` + +**Step 1: Write the failing test** + +```go +func TestDefaultConfig_PluginsDefaults(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Plugins.DefaultEnabled { + t.Fatal("plugins.default_enabled should default to true") + } + if len(cfg.Plugins.Enabled) != 0 || len(cfg.Plugins.Disabled) != 0 { + t.Fatal("plugins enabled/disabled should default empty") + } +} + +func TestConfig_PluginsJSONUnmarshal(t *testing.T) { + cfg := DefaultConfig() + err := json.Unmarshal([]byte(`{"plugins":{"default_enabled":false,"enabled":["policy-demo"],"disabled":["x"]}}`), cfg) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + if cfg.Plugins.DefaultEnabled { + t.Fatal("expected default_enabled=false from JSON") + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./pkg/config -run 'TestDefaultConfig_PluginsDefaults|TestConfig_PluginsJSONUnmarshal' -v` +Expected: FAIL with `cfg.Plugins undefined`. + +**Step 3: Write minimal implementation** + +```go +type PluginsConfig struct { + DefaultEnabled bool `json:"default_enabled"` + Enabled []string `json:"enabled,omitempty"` + Disabled []string `json:"disabled,omitempty"` +} + +type Config struct { + // ...existing fields... 
+ Plugins PluginsConfig `json:"plugins,omitempty"` +} +``` + +```go +Plugins: PluginsConfig{ + DefaultEnabled: true, + Enabled: []string{}, + Disabled: []string{}, +}, +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./pkg/config -run 'TestDefaultConfig_PluginsDefaults|TestConfig_PluginsJSONUnmarshal' -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add pkg/config/config.go pkg/config/defaults.go pkg/config/config_test.go +git commit -m "feat(config): add plugins selection config schema" +``` + +### Task 2: Build Deterministic Selection Resolver (Phase 2) + +**Files:** +- Modify: `pkg/plugin/manager.go` +- Test: `pkg/plugin/manager_test.go` + +**Step 1: Write the failing test** + +```go +func TestResolveSelection_DefaultEnabled(t *testing.T) {} +func TestResolveSelection_EnabledListOnly(t *testing.T) {} +func TestResolveSelection_DisabledWinsOverlap(t *testing.T) {} +func TestResolveSelection_UnknownEnabledFails(t *testing.T) {} +func TestResolveSelection_UnknownDisabledWarns(t *testing.T) {} +func TestResolveSelection_NormalizeAndDedupe(t *testing.T) {} +``` + +In each test, assert deterministic sorted resolution and expected error/warning behavior. + +**Step 2: Run test to verify it fails** + +Run: `go test ./pkg/plugin -run 'TestResolveSelection_' -v` +Expected: FAIL with undefined resolver types/functions. 
+ +**Step 3: Write minimal implementation** + +```go +type SelectionInput struct { + DefaultEnabled bool + Enabled []string + Disabled []string +} + +type SelectionResult struct { + EnabledNames []string + DisabledNames []string + UnknownEnabled []string + UnknownDisabled []string + Warnings []string +} + +func NormalizePluginName(s string) string { /* strings.TrimSpace + strings.ToLower */ } +func ResolveSelection(available []string, in SelectionInput) (SelectionResult, error) { /* deterministic rules */ } +``` + +Rules implemented exactly: +- unknown in `enabled` => error +- unknown in `disabled` => warning bucket +- overlap => disabled wins +- sorted deterministic output + +**Step 4: Run test to verify it passes** + +Run: `go test ./pkg/plugin -run 'TestResolveSelection_' -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add pkg/plugin/manager.go pkg/plugin/manager_test.go +git commit -m "feat(plugin): add deterministic plugin selection resolver" +``` + +### Task 3: Add Built-in Plugin Catalog Without Import Cycles (Phase 2) + +**Files:** +- Create: `pkg/plugin/builtin/catalog.go` +- Test: `pkg/plugin/builtin/catalog_test.go` + +**Step 1: Write the failing test** + +```go +func TestCatalogContainsPolicyDemo(t *testing.T) { + c := Catalog() + fn, ok := c["policy-demo"] + if !ok { + t.Fatal("expected policy-demo in builtin catalog") + } + if fn() == nil { + t.Fatal("expected non-nil plugin instance") + } +} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./pkg/plugin/builtin -run TestCatalogContainsPolicyDemo -v` +Expected: FAIL with package/file missing. 
+ +**Step 3: Write minimal implementation** + +```go +package builtin + +import ( + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/demoplugin" +) + +type Factory func() plugin.Plugin + +func Catalog() map[string]Factory { + return map[string]Factory{ + "policy-demo": func() plugin.Plugin { + return demoplugin.NewPolicyDemoPlugin(demoplugin.PolicyDemoConfig{}) + }, + } +} + +func Names() []string { /* return sorted keys */ } +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./pkg/plugin/builtin -run TestCatalogContainsPolicyDemo -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add pkg/plugin/builtin/catalog.go pkg/plugin/builtin/catalog_test.go +git commit -m "feat(plugin): add builtin plugin catalog package" +``` + +### Task 4: Add Bootstrap Resolver Module for Agent/Gateway (Phase 2) + +**Files:** +- Create: `cmd/picoclaw/internal/pluginruntime/bootstrap.go` +- Test: `cmd/picoclaw/internal/pluginruntime/bootstrap_test.go` + +**Step 1: Write the failing test** + +```go +func TestResolveConfiguredPlugins_UnknownEnabledReturnsError(t *testing.T) {} +func TestResolveConfiguredPlugins_ReturnsDeterministicInstances(t *testing.T) {} +func TestResolveConfiguredPlugins_UnknownDisabledWarns(t *testing.T) {} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./cmd/picoclaw/internal/pluginruntime -run 'TestResolveConfiguredPlugins_' -v` +Expected: FAIL with package/file missing. + +**Step 3: Write minimal implementation** + +```go +type Summary struct { + Enabled []string + Disabled []string + UnknownEnabled []string + UnknownDisabled []string + Warnings []string +} + +func ResolveConfiguredPlugins(cfg *config.Config) ([]plugin.Plugin, Summary, error) { + // 1) get catalog names from builtin.Names() + // 2) call plugin.ResolveSelection(...) 
+ // 3) instantiate enabled plugins via builtin.Catalog factories + // 4) return instances + summary +} +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./cmd/picoclaw/internal/pluginruntime -run 'TestResolveConfiguredPlugins_' -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add cmd/picoclaw/internal/pluginruntime/bootstrap.go cmd/picoclaw/internal/pluginruntime/bootstrap_test.go +git commit -m "feat(cli): add plugin runtime bootstrap resolver" +``` + +### Task 5: Wire Phase 2 Plugin Bootstrap into `agent` and `gateway` + +**Files:** +- Modify: `cmd/picoclaw/internal/agent/helpers.go` +- Modify: `cmd/picoclaw/internal/gateway/helpers.go` +- Test: `cmd/picoclaw/internal/agent/command_test.go` +- Test: `cmd/picoclaw/internal/gateway/command_test.go` + +**Step 1: Write the failing test** + +Add focused assertions that command constructors remain stable after importing plugin bootstrap package and wiring helper calls (no regressions in command metadata). +If needed, add table-driven compile/runtime smoke tests in a new `_test.go` under each package. + +**Step 2: Run test to verify it fails** + +Run: `go test ./cmd/picoclaw/internal/agent ./cmd/picoclaw/internal/gateway -run 'TestNew.*Command|Test.*Plugin.*' -v` +Expected: FAIL once bootstrap calls are referenced but not integrated correctly. 
+ +**Step 3: Write minimal implementation** + +```go +pluginsToEnable, summary, err := pluginruntime.ResolveConfiguredPlugins(cfg) +if err != nil { + return fmt.Errorf("resolve plugins: %w", err) +} +if len(pluginsToEnable) > 0 { + if err := agentLoop.EnablePlugins(pluginsToEnable...); err != nil { + return fmt.Errorf("enable plugins: %w", err) + } +} +logger.InfoCF("plugin", "Plugin selection resolved", map[string]any{ + "enabled": summary.Enabled, "disabled": summary.Disabled, + "unknown_disabled": summary.UnknownDisabled, +}) +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./cmd/picoclaw/internal/agent ./cmd/picoclaw/internal/gateway -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add cmd/picoclaw/internal/agent/helpers.go cmd/picoclaw/internal/gateway/helpers.go cmd/picoclaw/internal/agent/command_test.go cmd/picoclaw/internal/gateway/command_test.go +git commit -m "feat(cli): wire plugin selection into agent and gateway startup" +``` + +### Task 6: Expose Plugin Resolution in Startup Diagnostics (Phase 2) + +**Files:** +- Modify: `pkg/agent/loop.go` +- Test: `pkg/agent/loop_test.go` +- Test: `pkg/agent/plugin_test.go` + +**Step 1: Write the failing test** + +```go +func TestGetStartupInfo_IncludesPluginSummary(t *testing.T) { + // create AgentLoop, enable a test plugin, call GetStartupInfo + // assert "plugins" key exists with enabled list +} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./pkg/agent -run 'TestGetStartupInfo_IncludesPluginSummary' -v` +Expected: FAIL because `plugins` section is absent. 
+ +**Step 3: Write minimal implementation** + +```go +if al.pluginManager != nil { + info["plugins"] = map[string]any{ + "enabled": al.pluginManager.Names(), + "count": len(al.pluginManager.Names()), + } +} else { + info["plugins"] = map[string]any{ + "enabled": []string{}, + "count": 0, + } +} +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./pkg/agent -run 'TestGetStartupInfo_IncludesPluginSummary|TestSetPluginManagerInstallsHookRegistry' -v` +Expected: PASS. + +**Step 5: Commit** + +```bash +git add pkg/agent/loop.go pkg/agent/loop_test.go pkg/agent/plugin_test.go +git commit -m "feat(agent): include plugin summary in startup diagnostics" +``` + +### Task 7: Add Non-Breaking Plugin Metadata Introspection (Phase 3) + +**Files:** +- Modify: `pkg/plugin/manager.go` +- Test: `pkg/plugin/manager_test.go` + +**Step 1: Write the failing test** + +```go +func TestDescribeAll_UsesDescriptorWhenImplemented(t *testing.T) {} +func TestDescribeAll_FallsBackForPlainPlugin(t *testing.T) {} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./pkg/plugin -run 'TestDescribeAll_' -v` +Expected: FAIL with undefined `PluginInfo` / `DescribeAll`. + +**Step 3: Write minimal implementation** + +```go +type PluginInfo struct { + Name string `json:"name"` + APIVersion string `json:"api_version"` + Status string `json:"status"` +} + +type PluginDescriptor interface { + Info() PluginInfo +} + +func (m *Manager) DescribeAll() []PluginInfo { /* include fallback info */ } +func (m *Manager) DescribeEnabled() []PluginInfo { /* status=enabled */ } +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./pkg/plugin -run 'TestDescribeAll_|TestRegisterPluginAndTriggerHook' -v` +Expected: PASS. 
+ +**Step 5: Commit** + +```bash +git add pkg/plugin/manager.go pkg/plugin/manager_test.go +git commit -m "feat(plugin): add non-breaking plugin metadata introspection" +``` + +### Task 8: Add `picoclaw plugin list` Command (Phase 3) + +**Files:** +- Create: `cmd/picoclaw/internal/plugin/command.go` +- Create: `cmd/picoclaw/internal/plugin/list.go` +- Test: `cmd/picoclaw/internal/plugin/command_test.go` +- Test: `cmd/picoclaw/internal/plugin/list_test.go` +- Modify: `cmd/picoclaw/main.go` +- Modify: `cmd/picoclaw/main_test.go` + +**Step 1: Write the failing test** + +```go +func TestNewPluginCommand(t *testing.T) {} +func TestNewListSubcommand(t *testing.T) {} +func TestNewPicoclawCommand_IncludesPluginCommand(t *testing.T) {} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./cmd/picoclaw/internal/plugin ./cmd/picoclaw -run 'TestNewPluginCommand|TestNewListSubcommand|TestNewPicoclawCommand' -v` +Expected: FAIL with missing package/command registration. + +**Step 3: Write minimal implementation** + +```go +func NewPluginCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "plugin", + Short: "Inspect and validate plugins", + RunE: func(cmd *cobra.Command, _ []string) error { return cmd.Help() }, + } + cmd.AddCommand(newListCommand()) + return cmd +} +``` + +`newListCommand()` should load config, resolve selection, and print text or JSON list. + +**Step 4: Run test to verify it passes** + +Run: `go test ./cmd/picoclaw/internal/plugin ./cmd/picoclaw -v` +Expected: PASS. 
+ +**Step 5: Commit** + +```bash +git add cmd/picoclaw/internal/plugin/command.go cmd/picoclaw/internal/plugin/list.go cmd/picoclaw/internal/plugin/command_test.go cmd/picoclaw/internal/plugin/list_test.go cmd/picoclaw/main.go cmd/picoclaw/main_test.go +git commit -m "feat(cli): add plugin list command" +``` + +### Task 9: Add `picoclaw plugin lint` Command (Phase 3) + +**Files:** +- Create: `cmd/picoclaw/internal/plugin/lint.go` +- Test: `cmd/picoclaw/internal/plugin/lint_test.go` +- Modify: `cmd/picoclaw/internal/plugin/command.go` +- Modify: `cmd/picoclaw/internal/pluginruntime/bootstrap.go` +- Modify: `cmd/picoclaw/internal/pluginruntime/bootstrap_test.go` + +**Step 1: Write the failing test** + +```go +func TestPluginLint_ValidConfigExitZero(t *testing.T) {} +func TestPluginLint_UnknownEnabledExitNonZero(t *testing.T) {} +``` + +**Step 2: Run test to verify it fails** + +Run: `go test ./cmd/picoclaw/internal/plugin ./cmd/picoclaw/internal/pluginruntime -run 'TestPluginLint_|TestResolveConfiguredPlugins_' -v` +Expected: FAIL with missing lint command/validation path. + +**Step 3: Write minimal implementation** + +```go +func newLintCommand() *cobra.Command { + var configPath string + cmd := &cobra.Command{ + Use: "lint", + Short: "Validate plugin configuration", + RunE: func(_ *cobra.Command, _ []string) error { + cfg, err := config.LoadConfig(configPath) + if err != nil { return err } + _, _, err = pluginruntime.ResolveConfiguredPlugins(cfg) + return err + }, + } + cmd.Flags().StringVar(&configPath, "config", internal.GetConfigPath(), "Path to config.json") + return cmd +} +``` + +**Step 4: Run test to verify it passes** + +Run: `go test ./cmd/picoclaw/internal/plugin ./cmd/picoclaw/internal/pluginruntime -v` +Expected: PASS. 
+ +**Step 5: Commit** + +```bash +git add cmd/picoclaw/internal/plugin/lint.go cmd/picoclaw/internal/plugin/lint_test.go cmd/picoclaw/internal/plugin/command.go cmd/picoclaw/internal/pluginruntime/bootstrap.go cmd/picoclaw/internal/pluginruntime/bootstrap_test.go +git commit -m "feat(cli): add plugin lint command" +``` + +### Task 10: Documentation and Final Verification + +**Files:** +- Modify: `docs/plugin-system-roadmap.md` +- Modify: `docs/plans/2026-02-28-plugin-system-phase2-phase3-design.md` +- Optional Modify: `README.md` (if command docs are surfaced there) + +**Step 1: Write docs-oriented failing checks** + +Add/update checklist assertions in docs PR description: +- Phase 2 gates explicitly checked. +- Phase 3 list/lint behavior and exit semantics documented. + +**Step 2: Run verification commands** + +Run: + +```bash +go test ./pkg/config ./pkg/plugin ./pkg/plugin/builtin ./pkg/agent ./cmd/picoclaw ./cmd/picoclaw/internal/plugin ./cmd/picoclaw/internal/pluginruntime -v +``` + +Expected: PASS. + +**Step 3: Minimal doc implementation** + +Document: +- JSON `plugins` config examples +- deterministic precedence rules +- `plugin list` usage (`--format json`) +- `plugin lint --config` usage and non-zero behavior + +**Step 4: Re-run verification** + +Run: + +```bash +go test ./pkg/config ./pkg/plugin ./pkg/plugin/builtin ./pkg/agent ./cmd/picoclaw ./cmd/picoclaw/internal/plugin ./cmd/picoclaw/internal/pluginruntime -v +``` + +Expected: PASS. + +**Step 5: Commit** + +```bash +git add docs/plugin-system-roadmap.md docs/plans/2026-02-28-plugin-system-phase2-phase3-design.md README.md +git commit -m "docs(plugin): document phase2/phase3 behavior and cli usage" +``` + +--- + +## PR Plan (maintainer-friendly) + +1. PR-1: Tasks 1-2 (`config` + resolver). +2. PR-2: Tasks 3-6 (catalog + bootstrap + startup diagnostics). +3. PR-3: Tasks 7-9 (metadata + plugin list/lint CLI). +4. PR-4: Task 10 docs-only cleanup if needed. 
+ +Each PR should include: +- complete PR template fields +- AI disclosure +- test environment and command evidence +- unresolved comments = 0 before merge diff --git a/docs/plugin-system-roadmap.md b/docs/plugin-system-roadmap.md new file mode 100644 index 0000000000..b846f4f608 --- /dev/null +++ b/docs/plugin-system-roadmap.md @@ -0,0 +1,126 @@ +# Plugin System Roadmap + +This document defines how PicoClaw evolves from hook-based extension points to a fuller plugin system in low-risk phases. + +## Current Status (Phase 0: Foundation) + +Implemented in current hooks PR: + +- Typed lifecycle hooks (`pkg/hooks`) +- Priority-based handler ordering +- Cancellation support for modifying hooks +- Panic recovery and error isolation +- Agent-loop integration via `agentLoop.SetHooks(...)` + +Compatibility: + +- If no hooks are registered, runtime behavior is unchanged. +- No config migration is required. + +## Non-Goals in Phase 0 + +- No dynamic runtime plugin loading +- No remote plugin marketplace/distribution +- No plugin sandboxing model +- No stable external plugin ABI yet +- No Go `.so` plugin loading as default direction + +## Phase Plan + +## Phase 1: Static Plugin Contract (Compile-time) — Implemented + +Goal: define a minimal public plugin contract for Go modules. + +Implemented: + +- Add `pkg/plugin` with a small interface: + - `Name() string` + - `APIVersion() string` + - `Register(*hooks.HookRegistry) error` +- Register plugins at startup in code. +- Add compatibility metadata (`plugin.APIVersion`) and registration-time checks. + +Exit criteria (met): + +- Example plugin module builds against the contract. +- Startup validation logs loaded plugins and registration errors clearly. + +## Phase 2: Config-driven Enable/Disable — Implemented + +Goal: operational control without code changes. 
+ +Implemented: + +- Add typed plugin selection config in `config.json`: + - `plugins.default_enabled` + - `plugins.enabled` + - `plugins.disabled` +- Add deterministic plugin resolution and conflict handling in the plugin manager. +- Wire resolved plugins into startup for both `agent` and `gateway` entrypoints. + +Exit criteria (met): + +- Users can toggle built-in plugins without rebuilding. +- Invalid plugin selection in config is surfaced during startup/lint flow. + +## Phase 3: Metadata Introspection + CLI — Implemented + +Goal: make plugin state inspectable and config validation straightforward. + +Implemented: + +- Add plugin metadata introspection in the plugin manager (internal API surface). +- Add CLI inspection commands: + - `picoclaw plugin list` + - `picoclaw plugin list --format json` +- Add CLI lint command: + - `picoclaw plugin lint --config <path>` +- Add startup plugin resolution summary diagnostics: + - `plugins_enabled` + - `plugins_disabled` + - `plugins_unknown_enabled` + - `plugins_unknown_disabled` + - `plugins_warnings` + +Exit criteria (met): + +- Operators can inspect plugin status in text/JSON outputs (`name`, `status`). +- Plugin metadata introspection is available via plugin manager APIs. +- Operators can validate plugin config before startup. + +## Future DX Work (Post-Phase 3) + +- Provide `examples/plugins/*` reference implementations. +- Publish plugin authoring guide (lifecycle map, best practices, safety constraints). +- Add plugin-focused test harness patterns for hook behavior verification. + +## Phase 4: Optional Dynamic Loading (Separate RFC) + +Goal: support runtime-loaded plugins only if security and operability are acceptable. + +Preferred direction: + +- Runtime plugins run as subprocesses. +- Host and plugin communicate via RPC/gRPC. +- Host manages lifecycle (spawn/health/timeout/restart), not in-process dynamic loading. 
+ +Why this direction: + +- Go native `.so` plugin loading has strict toolchain/ABI coupling with host binary. +- Subprocess RPC model reduces coupling and improves fault isolation. +- Process boundary provides a cleaner place for permissions and sandbox controls. + +Preconditions: + +- Threat model approved +- Signature/trust model defined +- Sandboxing and permission boundaries defined +- Rollback and safe-disable behavior validated +- Versioned RPC handshake and capability negotiation defined +- Process supervision policy defined (timeouts, retries, crash loop backoff) + +Until then, compile-time registration remains the recommended model. + +## Maintainer Review Notes + +The current hooks PR should be reviewed as Phase 0+1 only. It intentionally establishes extension points while avoiding high-risk runtime plugin mechanics. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8fd7328d10..a59604571d 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -22,8 +22,10 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/hooks" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/plugin" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -42,6 +44,8 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore + hooks *hooks.HookRegistry + pluginManager *plugin.Manager } // processOptions configures how a message is processed @@ -56,8 +60,6 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." 
- func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { registry := NewAgentRegistry(cfg, provider) @@ -172,61 +174,33 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - // Process message - func() { - // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. - // Currently disabled because files are deleted before the LLM can access their content. - // defer func() { - // if al.mediaStore != nil && msg.MediaScope != "" { - // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { - // logger.WarnCF("agent", "Failed to release media", map[string]any{ - // "scope": msg.MediaScope, - // "error": releaseErr.Error(), - // }) - // } - // } - // }() - - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) + } - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() - } + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). 
+ alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() } } + } - if !alreadySent { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) - logger.InfoCF("agent", "Published outbound response", - map[string]any{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "content_len": len(response), - }) - } else { - logger.DebugCF( - "agent", - "Skipped outbound (message tool already sent)", - map[string]any{"channel": msg.Channel}, - ) - } + if !alreadySent { + al.sendOutbound(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) } - }() + } } } @@ -284,6 +258,111 @@ func inferMediaType(filename, contentType string) string { return "file" } +// SetHooks installs a hook registry. Must be called before Run starts. +func (al *AgentLoop) SetHooks(h *hooks.HookRegistry) error { + if al.running.Load() { + return fmt.Errorf("SetHooks must be called before Run starts") + } + al.hooks = h + + // Rewire MessageTool callbacks to route through sendOutbound for hook interception. 
+ for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + if tool, ok := agent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + if h == nil { + mt.SetSendCallback(func(channel, chatID, content string) error { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + }) + continue + } + mt.SetSendCallback(func(channel, chatID, content string) error { + if sent, reason := al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }); !sent { + if strings.TrimSpace(reason) == "" { + reason = "unspecified" + } + return fmt.Errorf("message canceled by hook: %s", reason) + } + return nil + }) + } + } + } + } + return nil +} + +// SetPluginManager installs a plugin manager and routes its hook registry into the loop. +// Must be called before Run starts. +func (al *AgentLoop) SetPluginManager(pm *plugin.Manager) error { + if pm == nil { + if err := al.SetHooks(nil); err != nil { + return err + } + al.pluginManager = nil + return nil + } + if err := al.SetHooks(pm.HookRegistry()); err != nil { + return err + } + al.pluginManager = pm + return nil +} + +// EnablePlugins is a convenience helper to build and install a plugin manager. +func (al *AgentLoop) EnablePlugins(plugins ...plugin.Plugin) error { + pm := plugin.NewManager() + if err := pm.RegisterAll(plugins...); err != nil { + return err + } + return al.SetPluginManager(pm) +} + +// sendOutbound wraps bus.PublishOutbound with the message_sending hook. +// Returns whether the message was sent and, if canceled, the cancel reason. 
+func (al *AgentLoop) sendOutbound(ctx context.Context, msg bus.OutboundMessage) (bool, string) { + if ctx == nil { + ctx = context.Background() + } + if al.hooks != nil { + event := &hooks.MessageSendingEvent{Channel: msg.Channel, ChatID: msg.ChatID, Content: msg.Content} + al.hooks.TriggerMessageSending(ctx, event) + if event.Cancel { + reason := event.CancelReason + if reason == "" { + reason = "unspecified" + } + logger.WarnCF("hooks", "Outbound message canceled by hook", + map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "reason": reason, + }) + return false, reason + } + msg.Content = event.Content + } + if err := al.bus.PublishOutbound(ctx, msg); err != nil { + logger.WarnCF("agent", "Failed to publish outbound message", map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "error": err.Error(), + }) + return false, err.Error() + } + return true, "" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. 
func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -333,7 +412,7 @@ func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, cha Channel: channel, ChatID: chatID, UserMessage: content, - DefaultResponse: defaultResponse, + DefaultResponse: "I've completed processing but have no response to give.", EnableSummary: false, SendResponse: false, NoHistory: true, // Don't load session history for heartbeat @@ -356,6 +435,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "session_key": msg.SessionKey, }) + // Fire message_received hook + if al.hooks != nil { + al.hooks.TriggerMessageReceived(ctx, &hooks.MessageReceivedEvent{ + Channel: msg.Channel, + SenderID: msg.SenderID, + ChatID: msg.ChatID, + Content: msg.Content, + Media: msg.Media, + Metadata: msg.Metadata, + }) + } + // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) @@ -384,13 +475,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) } - // Reset message-tool state for this round so we don't skip publishing due to a previous round. 
- if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(msg.Channel, msg.ChatID) - } - } - // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) sessionKey := route.SessionKey if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { @@ -409,7 +493,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, - DefaultResponse: defaultResponse, + DefaultResponse: "I've completed processing but have no response to give.", EnableSummary: true, SendResponse: false, }) @@ -490,6 +574,18 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 1. Update tool contexts al.updateToolContexts(agent, opts.Channel, opts.ChatID) + // Fire session hooks + if al.hooks != nil { + sessionEvt := &hooks.SessionEvent{ + AgentID: agent.ID, + SessionKey: opts.SessionKey, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerSessionStart(ctx, sessionEvt) + defer al.hooks.TriggerSessionEnd(ctx, sessionEvt) + } + // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string @@ -529,12 +625,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) + al.maybeSummarize(ctx, agent, opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -576,7 +672,7 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan } // Use a short timeout so the goroutine does not block indefinitely when - // the outbound bus is full. 
Reasoning output is best-effort; dropping it + // the outbound bus is full. Reasoning output is best-effort; dropping it // is acceptable to avoid goroutine accumulation. pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second) defer pubCancel() @@ -587,7 +683,7 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan Content: reasoningContent, }); err != nil { // Treat context.DeadlineExceeded / context.Canceled as expected - // (bus full under load, or parent canceled). Check the error + // (bus full under load, or parent canceled). Check the error // itself rather than ctx.Err(), because pubCtx may time out // (5 s) while the parent ctx is still active. // Also treat ErrBusClosed as expected — it occurs during normal @@ -684,8 +780,19 @@ func (al *AgentLoop) runLLMIteration( } // Retry loop for context/token errors + llmStart := time.Now() maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { + // Fire llm_input hook (re-fires after compression so hooks see actual messages) + if al.hooks != nil { + al.hooks.TriggerLLMInput(ctx, &hooks.LLMInputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Messages: messages, + Tools: providerToolDefs, + Iteration: iteration, + }) + } response, err = callLLM() if err == nil { break @@ -729,7 +836,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. 
Compressing history and retrying...", @@ -748,6 +855,8 @@ func (al *AgentLoop) runLLMIteration( break } + llmDuration := time.Since(llmStart) + if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]any{ @@ -760,16 +869,18 @@ func (al *AgentLoop) runLLMIteration( go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) - logger.DebugCF("agent", "LLM response", - map[string]any{ - "agent_id": agent.ID, - "iteration": iteration, - "content_chars": len(response.Content), - "tool_calls": len(response.ToolCalls), - "reasoning": response.Reasoning, - "target_channel": al.targetReasoningChannelID(opts.Channel), - "channel": opts.Channel, + // Fire llm_output hook + if al.hooks != nil { + al.hooks.TriggerLLMOutput(ctx, &hooks.LLMOutputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Content: response.Content, + ToolCalls: response.ToolCalls, + Iteration: iteration, + Duration: llmDuration, }) + } + // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -832,9 +943,14 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) + assistantMsgIndex := len(messages) - 1 + assistantSessionIndex := -1 + if history := agent.Sessions.GetHistory(opts.SessionKey); len(history) > 0 { + assistantSessionIndex = len(history) - 1 + } // Execute tool calls - for _, tc := range normalizedToolCalls { + for tcIdx, tc := range normalizedToolCalls { argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), @@ -860,18 +976,74 @@ func (al *AgentLoop) runLLMIteration( } } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + // Fire before_tool_call hook + var toolResult *tools.ToolResult + 
toolCanceled := false + if al.hooks != nil { + args := tc.Arguments + if args == nil { + args = make(map[string]any) + } + btcEvent := &hooks.BeforeToolCallEvent{ + ToolName: tc.Name, + Args: args, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerBeforeToolCall(ctx, btcEvent) + if btcEvent.Cancel { + toolCanceled = true + reason := btcEvent.CancelReason + if strings.TrimSpace(reason) == "" { + reason = fmt.Sprintf("tool call %q was canceled by before_tool_call hook", tc.Name) + } + toolResult = tools.ErrorResult(reason) + } + tc.Arguments = btcEvent.Args + if tc.Arguments == nil { + tc.Arguments = make(map[string]any) + } + + // Keep persisted assistant tool-call arguments aligned with rewritten execution args. + updateToolCallArguments(&messages[assistantMsgIndex], tcIdx, tc.Arguments) + if assistantSessionIndex >= 0 { + history := agent.Sessions.GetHistory(opts.SessionKey) + if assistantSessionIndex < len(history) { + updateToolCallArguments(&history[assistantSessionIndex], tcIdx, tc.Arguments) + agent.Sessions.SetHistory(opts.SessionKey, history) + } + } + } + + var toolDuration time.Duration + if !toolCanceled { + toolStart := time.Now() + toolResult = agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + toolDuration = time.Since(toolStart) + } + + // Fire after_tool_call hook (fires for both executed and canceled calls) + if al.hooks != nil { + al.hooks.TriggerAfterToolCall(ctx, &hooks.AfterToolCallEvent{ + ToolName: tc.Name, + Args: tc.Arguments, + Channel: opts.Channel, + ChatID: opts.ChatID, + Duration: toolDuration, + Result: toolResult, + }) + } // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, @@ -898,7 +1070,7 @@ func (al 
*AgentLoop) runLLMIteration( } parts = append(parts, part) } - al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ + _ = al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Parts: parts, @@ -947,7 +1119,7 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { +func (al *AgentLoop) maybeSummarize(_ context.Context, agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) threshold := agent.ContextWindow * 75 / 100 @@ -957,6 +1129,13 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { defer al.summarizing.Delete(summarizeKey) + if !constants.IsInternalChannel(channel) { + al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: "Memory threshold reached. Optimizing conversation history...", + }) + } logger.Debug("Memory threshold reached. 
Optimizing conversation history...") al.summarizeSession(agent, sessionKey) }() @@ -1026,6 +1205,14 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { return info } + pluginNames := make([]string, 0) + if al.pluginManager != nil { + pluginNames = al.pluginManager.Names() + if pluginNames == nil { + pluginNames = make([]string, 0) + } + } + // Tools info toolsList := agent.Tools.List() info["tools"] = map[string]any{ @@ -1033,6 +1220,12 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { "names": toolsList, } + // Plugins info + info["plugins"] = map[string]any{ + "enabled": pluginNames, + "count": len(pluginNames), + } + // Skills info info["skills"] = agent.ContextBuilder.GetSkillsInfo() @@ -1045,6 +1238,19 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { return info } +// updateToolCallArguments patches the serialized arguments for a tool call in-place. +func updateToolCallArguments(msg *providers.Message, toolCallIndex int, args map[string]any) { + if msg == nil || toolCallIndex < 0 || toolCallIndex >= len(msg.ToolCalls) { + return + } + toolCall := &msg.ToolCalls[toolCallIndex] + if toolCall.Function == nil { + return + } + argumentsJSON, _ := json.Marshal(args) + toolCall.Function.Arguments = string(argumentsJSON) +} + // formatMessagesForLog formats messages for logging func formatMessagesForLog(messages []providers.Message) string { if len(messages) == 0 { @@ -1317,20 +1523,27 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } -// extractPeer extracts the routing peer from the inbound message's structured Peer field. +// extractPeer extracts the routing peer from inbound message metadata. 
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - if msg.Peer.Kind == "" { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + peerKind = msg.Peer.Kind + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + peerID = msg.Peer.ID + } + if peerKind == "" { return nil } - peerID := msg.Peer.ID if peerID == "" { - if msg.Peer.Kind == "direct" { + if peerKind == "direct" { peerID = msg.SenderID } else { peerID = msg.ChatID } } - return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} + return &routing.RoutePeer{Kind: peerKind, ID: peerID} } // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 801b6a46ed..fc2beb3907 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -321,6 +321,90 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } } +func TestGetStartupInfo_IncludesPluginSummary(t *testing.T) { + newLoop := func(t *testing.T) *AgentLoop { + t.Helper() + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + return NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + } + + t.Run("no plugins enabled", func(t *testing.T) { + al := newLoop(t) + info := al.GetStartupInfo() + + pluginsInfo, ok := info["plugins"].(map[string]any) + if !ok { + t.Fatal("Expected 'plugins' to be a map") + } + + count, ok := pluginsInfo["count"].(int) + if !ok { + t.Fatal("Expected plugin count to be an int") + } + if count != 0 { + t.Fatalf("Expected plugin count 0, got %d", count) + } + + enabled, ok := pluginsInfo["enabled"].([]string) + if !ok { + t.Fatal("Expected plugin enabled list to be []string") + } + if len(enabled) != 0 { + t.Fatalf("Expected no enabled plugins, got %v", enabled) + } + }) + + t.Run("plugins enabled", func(t *testing.T) { + al := newLoop(t) + if err := 
al.EnablePlugins(blockingPlugin{}); err != nil { + t.Fatalf("EnablePlugins failed: %v", err) + } + + info := al.GetStartupInfo() + pluginsInfo, ok := info["plugins"].(map[string]any) + if !ok { + t.Fatal("Expected 'plugins' to be a map") + } + + count, ok := pluginsInfo["count"].(int) + if !ok { + t.Fatal("Expected plugin count to be an int") + } + if count <= 0 { + t.Fatalf("Expected plugin count > 0, got %d", count) + } + + enabled, ok := pluginsInfo["enabled"].([]string) + if !ok { + t.Fatal("Expected plugin enabled list to be []string") + } + if len(enabled) == 0 { + t.Fatal("Expected at least one enabled plugin") + } + + found := false + for _, name := range enabled { + if name == "block-outbound" { + found = true + break + } + } + if !found { + t.Fatalf("Expected enabled plugin list to include block-outbound, got %v", enabled) + } + }) +} + // TestAgentLoop_Stop verifies Stop() sets running to false func TestAgentLoop_Stop(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/agent/plugin_test.go b/pkg/agent/plugin_test.go new file mode 100644 index 0000000000..c0b6f0625a --- /dev/null +++ b/pkg/agent/plugin_test.go @@ -0,0 +1,416 @@ +package agent + +import ( + "context" + "encoding/json" + "os" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +type blockingPlugin struct{} + +func (p blockingPlugin) Name() string { + return "block-outbound" +} + +func (p blockingPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p blockingPlugin) Register(r *hooks.HookRegistry) error { + r.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked by plugin" + return nil + }) + return nil +} + +type 
nilArgsProvider struct { + calls int +} + +func (p *nilArgsProvider) Chat( + _ context.Context, + _ []providers.Message, + _ []providers.ToolDefinition, + _ string, + _ map[string]any, +) (*providers.LLMResponse, error) { + if p.calls == 0 { + p.calls++ + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{ + { + ID: "tc-1", + Type: "function", + Name: "nil_args_tool", + Arguments: map[string]any{"seed": "value"}, + }, + }, + }, nil + } + p.calls++ + return &providers.LLMResponse{ + Content: "done", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (p *nilArgsProvider) GetDefaultModel() string { + return "test-model" +} + +type nilArgsCaptureTool struct { + receivedNil bool +} + +func (t *nilArgsCaptureTool) Name() string { + return "nil_args_tool" +} + +func (t *nilArgsCaptureTool) Description() string { + return "captures whether args are nil" +} + +func (t *nilArgsCaptureTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *nilArgsCaptureTool) Execute(_ context.Context, args map[string]any) *tools.ToolResult { + if args == nil { + t.receivedNil = true + } + return tools.SilentResult("ok") +} + +func TestSetPluginManagerInstallsHookRegistry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err != nil { + t.Fatalf("SetPluginManager: %v", err) + } + + if al.pluginManager == nil { + 
t.Fatal("expected plugin manager to be set") + } + if al.hooks != pm.HookRegistry() { + t.Fatal("expected agent loop hooks to use plugin manager registry") + } + + sent, reason := al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: "cli", + ChatID: "direct", + Content: "hello", + }) + if sent { + t.Fatal("expected outbound message to be blocked by plugin") + } + if reason == "" { + t.Fatal("expected cancel reason to be propagated") + } +} + +func TestSetHooksReturnsErrorWhenRunning(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + if err := al.SetHooks(hooks.NewHookRegistry()); err == nil { + t.Fatal("expected error when calling SetHooks while running") + } +} + +func TestSetPluginManagerDoesNotPartiallyUpdateOnError(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err == nil { + t.Fatal("expected SetPluginManager to fail while running") + } + if al.pluginManager != nil { + t.Fatal("expected plugin manager to remain unchanged on 
SetPluginManager failure") + } +} + +func TestBeforeToolCallHooksCannotLeaveToolArgsNil(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &nilArgsProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + captureTool := &nilArgsCaptureTool{} + al.RegisterTool(captureTool) + + r := hooks.NewHookRegistry() + r.OnBeforeToolCall("force-nil-args", 0, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + if e.ToolName == "nil_args_tool" { + e.Args = nil + } + return nil + }) + if setErr := al.SetHooks(r); setErr != nil { + t.Fatalf("SetHooks: %v", setErr) + } + + resp, err := al.ProcessDirectWithChannel(context.Background(), "run nil args test", "s1", "cli", "direct") + if err != nil { + t.Fatalf("ProcessDirectWithChannel: %v", err) + } + if resp != "done" { + t.Fatalf("expected final response 'done', got %q", resp) + } + if captureTool.receivedNil { + t.Fatal("expected tool args to be reinitialized to non-nil map") + } +} + +func TestSetHooksNilRestoresDirectMessageCallback(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + tool, ok := agent.Tools.Get("message") + if !ok { + t.Fatal("expected message tool") 
+ } + mt, ok := tool.(*tools.MessageTool) + if !ok { + t.Fatal("expected message tool type") + } + + reg := hooks.NewHookRegistry() + reg.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked-by-hook" + return nil + }) + if err := al.SetHooks(reg); err != nil { + t.Fatalf("SetHooks(reg): %v", err) + } + + blocked := mt.Execute(context.Background(), map[string]any{ + "content": "first", + "channel": "cli", + "chat_id": "direct", + }) + if !blocked.IsError { + t.Fatal("expected message tool call to fail while hooks are active") + } + if blocked.Err == nil || !strings.Contains(blocked.Err.Error(), "blocked-by-hook") { + t.Fatalf("expected hook cancel reason in error, got %#v", blocked.Err) + } + + ctxNoMsg, cancelNoMsg := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancelNoMsg() + if _, got := msgBus.SubscribeOutbound(ctxNoMsg); got { + t.Fatal("did not expect outbound message while hook cancellation is active") + } + + if err := al.SetHooks(nil); err != nil { + t.Fatalf("SetHooks(nil): %v", err) + } + + delivered := mt.Execute(context.Background(), map[string]any{ + "content": "second", + "channel": "cli", + "chat_id": "direct", + }) + if delivered.IsError { + t.Fatalf("expected message tool to succeed after SetHooks(nil), got %#v", delivered) + } + + ctxMsg, cancelMsg := context.WithTimeout(context.Background(), time.Second) + defer cancelMsg() + msg, got := msgBus.SubscribeOutbound(ctxMsg) + if !got { + t.Fatal("expected outbound message after SetHooks(nil)") + } + if msg.Content != "second" || msg.Channel != "cli" || msg.ChatID != "direct" { + t.Fatalf("unexpected outbound message: %#v", msg) + } +} + +func TestBeforeToolCallArgRewriteUpdatesAssistantTranscript(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := 
&config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &nilArgsProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(&nilArgsCaptureTool{}) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + sessionKey := "agent:" + defaultAgent.ID + ":s2" + + reg := hooks.NewHookRegistry() + reg.OnBeforeToolCall("rewrite-args", 0, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + e.Args["rewritten"] = "yes" + return nil + }) + if err := al.SetHooks(reg); err != nil { + t.Fatalf("SetHooks: %v", err) + } + + if _, err := al.ProcessDirectWithChannel( + context.Background(), + "run rewrite test", + sessionKey, + "cli", + "direct", + ); err != nil { + t.Fatalf("ProcessDirectWithChannel: %v", err) + } + + history := defaultAgent.Sessions.GetHistory(sessionKey) + + foundToolCall := false + for _, msg := range history { + if msg.Role != "assistant" || len(msg.ToolCalls) == 0 { + continue + } + if msg.ToolCalls[0].Function == nil { + t.Fatal("expected tool call function payload") + } + var args map[string]any + if err := json.Unmarshal([]byte(msg.ToolCalls[0].Function.Arguments), &args); err != nil { + t.Fatalf("failed to decode persisted tool call args: %v", err) + } + if got := args["rewritten"]; got != "yes" { + t.Fatalf("expected rewritten arg to be persisted, got %#v", got) + } + foundToolCall = true + break + } + + if !foundToolCall { + t.Fatal("expected assistant tool call message in session history") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index d84772d2b0..92549b7778 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -53,6 +53,7 @@ type Config struct { Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig 
`json:"providers,omitempty"` + Plugins PluginsConfig `json:"plugins,omitempty"` ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` @@ -167,6 +168,12 @@ type SessionConfig struct { IdentityLinks map[string][]string `json:"identity_links,omitempty"` } +type PluginsConfig struct { + DefaultEnabled bool `json:"default_enabled"` + Enabled []string `json:"enabled,omitempty"` + Disabled []string `json:"disabled,omitempty"` +} + type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 12fd10b50b..7e25b38293 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -243,6 +243,51 @@ func TestDefaultConfig_Temperature(t *testing.T) { } } +func TestDefaultConfig_PluginsDefaults(t *testing.T) { + cfg := DefaultConfig() + + if !cfg.Plugins.DefaultEnabled { + t.Error("Plugins.DefaultEnabled should be true by default") + } + if cfg.Plugins.Enabled == nil { + t.Error("Plugins.Enabled should be initialized to an empty slice") + } + if len(cfg.Plugins.Enabled) != 0 { + t.Errorf("Plugins.Enabled len = %d, want 0", len(cfg.Plugins.Enabled)) + } + if cfg.Plugins.Disabled == nil { + t.Error("Plugins.Disabled should be initialized to an empty slice") + } + if len(cfg.Plugins.Disabled) != 0 { + t.Errorf("Plugins.Disabled len = %d, want 0", len(cfg.Plugins.Disabled)) + } +} + +func TestConfig_PluginsJSONUnmarshal(t *testing.T) { + jsonData := `{ + "plugins": { + "default_enabled": false, + "enabled": ["plugin-a", "plugin-b"], + "disabled": ["plugin-c"] + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if cfg.Plugins.DefaultEnabled { + 
t.Error("Plugins.DefaultEnabled = true, want false") + } + if len(cfg.Plugins.Enabled) != 2 || cfg.Plugins.Enabled[0] != "plugin-a" || cfg.Plugins.Enabled[1] != "plugin-b" { + t.Errorf("Plugins.Enabled = %v, want [plugin-a plugin-b]", cfg.Plugins.Enabled) + } + if len(cfg.Plugins.Disabled) != 1 || cfg.Plugins.Disabled[0] != "plugin-c" { + t.Errorf("Plugins.Disabled = %v, want [plugin-c]", cfg.Plugins.Disabled) + } +} + // TestDefaultConfig_Gateway verifies gateway defaults func TestDefaultConfig_Gateway(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index ebb924859f..61f61ab34b 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -134,6 +134,11 @@ func DefaultConfig() *Config { Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, }, + Plugins: PluginsConfig{ + DefaultEnabled: true, + Enabled: []string{}, + Disabled: []string{}, + }, ModelList: []ModelConfig{ // ============================================ // Add your API key to the model you want to use diff --git a/pkg/hooks/hooks.go b/pkg/hooks/hooks.go new file mode 100644 index 0000000000..9865c86fe4 --- /dev/null +++ b/pkg/hooks/hooks.go @@ -0,0 +1,499 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const voidHookWaitBudget = 50 * time.Millisecond + +// HookHandler is the callback signature for all hooks. +type HookHandler[T any] func(ctx context.Context, event *T) error + +// HookRegistration tracks a handler with its priority and name. +type HookRegistration[T any] struct { + Handler HookHandler[T] + Priority int // Lower = runs first + Name string +} + +// HookRegistry manages all lifecycle hooks. 
+type HookRegistry struct { + messageReceived []HookRegistration[MessageReceivedEvent] + messageSending []HookRegistration[MessageSendingEvent] + beforeToolCall []HookRegistration[BeforeToolCallEvent] + afterToolCall []HookRegistration[AfterToolCallEvent] + llmInput []HookRegistration[LLMInputEvent] + llmOutput []HookRegistration[LLMOutputEvent] + sessionStart []HookRegistration[SessionEvent] + sessionEnd []HookRegistration[SessionEvent] + mu sync.RWMutex +} + +// NewHookRegistry creates an empty hook registry. +func NewHookRegistry() *HookRegistry { + return &HookRegistry{} +} + +// insertSorted inserts a registration into a new slice sorted by priority. +// Always allocates a new backing array so concurrent readers of the old slice are safe. +func insertSorted[T any](slice []HookRegistration[T], reg HookRegistration[T]) []HookRegistration[T] { + i := 0 + for i < len(slice) && slice[i].Priority <= reg.Priority { + i++ + } + result := make([]HookRegistration[T], len(slice)+1) + copy(result, slice[:i]) + result[i] = reg + copy(result[i+1:], slice[i:]) + return result +} + +// Registration methods + +func (r *HookRegistry) OnMessageReceived(name string, priority int, handler HookHandler[MessageReceivedEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageReceived = insertSorted(r.messageReceived, HookRegistration[MessageReceivedEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnMessageSending(name string, priority int, handler HookHandler[MessageSendingEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageSending = insertSorted(r.messageSending, HookRegistration[MessageSendingEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnBeforeToolCall(name string, priority int, handler HookHandler[BeforeToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.beforeToolCall = insertSorted(r.beforeToolCall, HookRegistration[BeforeToolCallEvent]{ + Handler: handler, Priority: 
priority, Name: name, + }) +} + +func (r *HookRegistry) OnAfterToolCall(name string, priority int, handler HookHandler[AfterToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.afterToolCall = insertSorted(r.afterToolCall, HookRegistration[AfterToolCallEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMInput(name string, priority int, handler HookHandler[LLMInputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmInput = insertSorted(r.llmInput, HookRegistration[LLMInputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMOutput(name string, priority int, handler HookHandler[LLMOutputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmOutput = insertSorted(r.llmOutput, HookRegistration[LLMOutputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionStart(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionStart = insertSorted(r.sessionStart, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionEnd(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionEnd = insertSorted(r.sessionEnd, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +// Trigger methods — void hooks + +func cloneMapStringString(src map[string]string) map[string]string { + if src == nil { + return nil + } + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = cloneAny(v) + } + return dst +} + +func cloneAny(v any) any { + if v == nil { + return nil + } + cloned := cloneReflectValue(reflect.ValueOf(v)) + if 
!cloned.IsValid() { + return nil + } + return cloned.Interface() +} + +func cloneReflectValue(v reflect.Value) reflect.Value { + if !v.IsValid() { + return v + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type().Elem()) + out.Elem().Set(cloneReflectValue(v.Elem())) + return out + case reflect.Interface: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type()).Elem() + out.Set(cloneReflectValue(v.Elem())) + return out + case reflect.Map: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeMapWithSize(v.Type(), v.Len()) + iter := v.MapRange() + for iter.Next() { + out.SetMapIndex(iter.Key(), cloneReflectValue(iter.Value())) + } + return out + case reflect.Slice: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeSlice(v.Type(), v.Len(), v.Len()) + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + case reflect.Array: + out := reflect.New(v.Type()).Elem() + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + case reflect.Struct: + out := reflect.New(v.Type()).Elem() + for i := range v.NumField() { + field := out.Field(i) + if !field.CanSet() { + // Preserve original value for structs with non-settable fields. 
+ return v + } + field.Set(cloneReflectValue(v.Field(i))) + } + return out + default: + return v + } +} + +func cloneToolCall(tc providers.ToolCall) providers.ToolCall { + out := tc + out.Arguments = cloneMapStringAny(tc.Arguments) + if tc.Function != nil { + f := *tc.Function + out.Function = &f + } + if tc.ExtraContent != nil { + ec := *tc.ExtraContent + if tc.ExtraContent.Google != nil { + g := *tc.ExtraContent.Google + ec.Google = &g + } + out.ExtraContent = &ec + } + return out +} + +func cloneMessage(msg providers.Message) providers.Message { + out := msg + if msg.ToolCalls != nil { + out.ToolCalls = make([]providers.ToolCall, len(msg.ToolCalls)) + for i := range msg.ToolCalls { + out.ToolCalls[i] = cloneToolCall(msg.ToolCalls[i]) + } + } + if msg.SystemParts != nil { + out.SystemParts = make([]providers.ContentBlock, len(msg.SystemParts)) + for i := range msg.SystemParts { + part := msg.SystemParts[i] + if part.CacheControl != nil { + cc := *part.CacheControl + part.CacheControl = &cc + } + out.SystemParts[i] = part + } + } + return out +} + +func cloneToolDefinition(td providers.ToolDefinition) providers.ToolDefinition { + out := td + out.Function = td.Function + out.Function.Parameters = cloneMapStringAny(td.Function.Parameters) + return out +} + +func cloneVoidEvent[T any](event *T) *T { + if event == nil { + return nil + } + + switch e := any(event).(type) { + case *MessageReceivedEvent: + c := *e + if e.Media != nil { + c.Media = append([]string(nil), e.Media...) 
+ } + c.Metadata = cloneMapStringString(e.Metadata) + return any(&c).(*T) + case *AfterToolCallEvent: + c := *e + c.Args = cloneMapStringAny(e.Args) + if e.Result != nil { + r := *e.Result + c.Result = &r + } + return any(&c).(*T) + case *LLMInputEvent: + c := *e + if e.Messages != nil { + c.Messages = make([]providers.Message, len(e.Messages)) + for i := range e.Messages { + c.Messages[i] = cloneMessage(e.Messages[i]) + } + } + if e.Tools != nil { + c.Tools = make([]providers.ToolDefinition, len(e.Tools)) + for i := range e.Tools { + c.Tools[i] = cloneToolDefinition(e.Tools[i]) + } + } + return any(&c).(*T) + case *LLMOutputEvent: + c := *e + if e.ToolCalls != nil { + c.ToolCalls = make([]providers.ToolCall, len(e.ToolCalls)) + for i := range e.ToolCalls { + c.ToolCalls[i] = cloneToolCall(e.ToolCalls[i]) + } + } + return any(&c).(*T) + case *SessionEvent: + c := *e + return any(&c).(*T) + default: + c := *event + return &c + } +} + +// triggerVoid runs all handlers concurrently. +// It waits up to a small budget to collect immediate completions, then +// continues fail-open to avoid blocking the core agent pipeline. +// Each handler receives a cloned event to avoid shared-state mutation races. +// Errors are logged but do not propagate to the caller. 
+func triggerVoid[T any](ctx context.Context, hooks []HookRegistration[T], event *T, hookName string) { + if len(hooks) == 0 { + return + } + var wg sync.WaitGroup + for _, h := range hooks { + wg.Add(1) + go func(reg HookRegistration[T]) { + defer wg.Done() + eventCopy := cloneVoidEvent(event) + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := reg.Handler(ctx, eventCopy); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "error": err.Error(), + }) + } + }(h) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + logger.WarnCF("hooks", "Void hook dispatch interrupted by context", + map[string]any{ + "hook": hookName, + }) + case <-time.After(voidHookWaitBudget): + logger.WarnCF("hooks", "Void hook dispatch exceeded wait budget; continuing", + map[string]any{ + "hook": hookName, + "wait_budget_ms": voidHookWaitBudget.Milliseconds(), + }) + } +} + +// triggerModifying runs handlers sequentially by priority, stopping if Cancel is set. +// The cancelCheck function inspects the event to determine if Cancel was set. 
+func triggerModifying[T any]( + ctx context.Context, + hooks []HookRegistration[T], + event *T, + hookName string, + cancelCheck func(*T) bool, +) { + if len(hooks) == 0 { + return + } + for _, h := range hooks { + func() { + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := h.Handler(ctx, event); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "error": err.Error(), + }) + } + }() + if cancelCheck(event) { + logger.InfoCF("hooks", "Hook canceled operation", + map[string]any{ + "hook": hookName, + "handler": h.Name, + }) + return + } + } +} + +// TriggerMessageReceived fires all message_received handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerMessageReceived(ctx context.Context, event *MessageReceivedEvent) { + r.mu.RLock() + hooks := r.messageReceived + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "message_received") +} + +func (r *HookRegistry) TriggerMessageSending(ctx context.Context, event *MessageSendingEvent) { + r.mu.RLock() + hooks := r.messageSending + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "message_sending", func(e *MessageSendingEvent) bool { + return e.Cancel + }) +} + +func (r *HookRegistry) TriggerBeforeToolCall(ctx context.Context, event *BeforeToolCallEvent) { + r.mu.RLock() + hooks := r.beforeToolCall + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "before_tool_call", func(e *BeforeToolCallEvent) bool { + return e.Cancel + }) +} + +// TriggerAfterToolCall fires all after_tool_call handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. 
+func (r *HookRegistry) TriggerAfterToolCall(ctx context.Context, event *AfterToolCallEvent) { + r.mu.RLock() + hooks := r.afterToolCall + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "after_tool_call") +} + +// TriggerLLMInput fires all llm_input handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMInput(ctx context.Context, event *LLMInputEvent) { + r.mu.RLock() + hooks := r.llmInput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_input") +} + +// TriggerLLMOutput fires all llm_output handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMOutput(ctx context.Context, event *LLMOutputEvent) { + r.mu.RLock() + hooks := r.llmOutput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_output") +} + +// TriggerSessionStart fires all session_start handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerSessionStart(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionStart + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_start") +} + +// TriggerSessionEnd fires all session_end handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. 
+func (r *HookRegistry) TriggerSessionEnd(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionEnd + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_end") +} diff --git a/pkg/hooks/hooks_test.go b/pkg/hooks/hooks_test.go new file mode 100644 index 0000000000..d21467a550 --- /dev/null +++ b/pkg/hooks/hooks_test.go @@ -0,0 +1,657 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestNewHookRegistry(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Triggering all hooks on an empty registry should not panic. + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "hello"}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hello"}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "t"}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ToolName: "t"}) + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a"}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a"}) + r.TriggerSessionStart(ctx, &SessionEvent{AgentID: "a"}) + r.TriggerSessionEnd(ctx, &SessionEvent{AgentID: "a"}) +} + +func TestVoidHookExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called atomic.Bool + r.OnMessageReceived("test", 0, func(_ context.Context, e *MessageReceivedEvent) error { + called.Store(true) + if e.Content != "ping" { + t.Errorf("Expected content 'ping', got '%s'", e.Content) + } + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "ping"}) + + if !called.Load() { + t.Error("Expected handler to be called") + } +} + +func TestVoidHooksConcurrent(t *testing.T) { + r := NewHookRegistry() + ctx := 
context.Background() + + var count atomic.Int32 + started := make(chan struct{}, 5) + release := make(chan struct{}) + done := make(chan struct{}) + + for i := range 5 { + r.OnMessageReceived("hook-"+string(rune('A'+i)), i, func(_ context.Context, _ *MessageReceivedEvent) error { + started <- struct{}{} + <-release + count.Add(1) + return nil + }) + } + + go func() { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + close(done) + }() + + // All 5 handlers must reach the barrier concurrently. + for i := range 5 { + select { + case <-started: + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for handler %d to start", i+1) + } + } + + // Release all handlers. + close(release) + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for handlers to complete") + } + + if count.Load() != 5 { + t.Errorf("Expected 5 handlers called, got %d", count.Load()) + } +} + +func TestVoidHooksReceiveIsolatedMessageReceivedEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageReceived("mutator-a", 0, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-a" + e.Media[0] = "changed-media-a" + e.Metadata["k"] = "changed-a" + e.Metadata["new-a"] = "x" + return nil + }) + r.OnMessageReceived("mutator-b", 1, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-b" + e.Media = append(e.Media, "extra") + e.Metadata["k"] = "changed-b" + e.Metadata["new-b"] = "y" + return nil + }) + + event := &MessageReceivedEvent{ + Content: "original", + Media: []string{"m1"}, + Metadata: map[string]string{"k": "v"}, + } + r.TriggerMessageReceived(ctx, event) + + if event.Content != "original" { + t.Fatalf("expected original content to remain unchanged, got %q", event.Content) + } + if len(event.Media) != 1 || event.Media[0] != "m1" { + t.Fatalf("expected original media to remain unchanged, got %#v", event.Media) + } + if got := 
event.Metadata["k"]; got != "v" { + t.Fatalf("expected metadata[k] to remain v, got %q", got) + } + if _, ok := event.Metadata["new-a"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-a") + } + if _, ok := event.Metadata["new-b"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-b") + } +} + +func TestVoidHooksReceiveIsolatedAfterToolCallEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnAfterToolCall("mutator-a", 0, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-a" + e.Result.ForLLM = "mutated-a" + return nil + }) + r.OnAfterToolCall("mutator-b", 1, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-b" + e.Args["new"] = "v" + e.Result.ForUser = "mutated-b" + return nil + }) + + event := &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"k": "original"}, + Result: &tools.ToolResult{ + ForLLM: "for-llm", + ForUser: "for-user", + }, + } + + // Use a local copy so we can compare immutable expectations. 
+ r.TriggerAfterToolCall(ctx, event) + + if got := event.Args["k"]; got != "original" { + t.Fatalf("expected args[k] to remain original, got %#v", got) + } + if _, ok := event.Args["new"]; ok { + t.Fatal("unexpected args mutation leaked from hook") + } + if event.Result.ForLLM != "for-llm" { + t.Fatalf("expected original result.ForLLM to remain unchanged, got %q", event.Result.ForLLM) + } + if event.Result.ForUser != "for-user" { + t.Fatalf("expected original result.ForUser to remain unchanged, got %q", event.Result.ForUser) + } +} + +func TestVoidHooksReceiveIsolatedLLMInputToolSchema(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnLLMInput("mutator", 0, func(_ context.Context, e *LLMInputEvent) error { + required, ok := e.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should be []string") + } + required[0] = "mutated" + e.Tools[0].Function.Parameters["required"] = append(required, "extra") + return nil + }) + + event := &LLMInputEvent{ + AgentID: "a1", + Model: "m1", + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "message", + Parameters: map[string]any{ + "type": "object", + "required": []string{"content"}, + }, + }, + }, + }, + } + + r.TriggerLLMInput(ctx, event) + + required, ok := event.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should remain []string") + } + if len(required) != 1 || required[0] != "content" { + t.Fatalf("expected required to remain unchanged, got %#v", required) + } +} + +func TestVoidHooksReceiveIsolatedStructValuesInMap(t *testing.T) { + type schemaSpec struct { + Required []string + Meta map[string]string + } + + r := NewHookRegistry() + ctx := context.Background() + + r.OnLLMInput("struct-mutator", 0, func(_ context.Context, e *LLMInputEvent) error { + spec, ok := e.Tools[0].Function.Parameters["schema"].(schemaSpec) + if !ok { + t.Fatal("schema should be 
schemaSpec") + } + spec.Required[0] = "mutated" + spec.Meta["k"] = "changed" + e.Tools[0].Function.Parameters["schema"] = spec + return nil + }) + + event := &LLMInputEvent{ + AgentID: "a1", + Model: "m1", + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "message", + Parameters: map[string]any{ + "schema": schemaSpec{ + Required: []string{"content"}, + Meta: map[string]string{"k": "v"}, + }, + }, + }, + }, + }, + } + + r.TriggerLLMInput(ctx, event) + + spec, ok := event.Tools[0].Function.Parameters["schema"].(schemaSpec) + if !ok { + t.Fatal("schema should remain schemaSpec") + } + if len(spec.Required) != 1 || spec.Required[0] != "content" { + t.Fatalf("expected required to remain unchanged, got %#v", spec.Required) + } + if got := spec.Meta["k"]; got != "v" { + t.Fatalf("expected meta[k] to remain v, got %q", got) + } +} + +func TestVoidHooksFailOpenOnSlowHandler(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + started := make(chan struct{}) + release := make(chan struct{}) + done := make(chan struct{}) + + r.OnLLMInput("slow", 0, func(_ context.Context, _ *LLMInputEvent) error { + close(started) + <-release + close(done) + return nil + }) + + begin := time.Now() + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a1"}) + elapsed := time.Since(begin) + + if elapsed > voidHookWaitBudget*3 { + t.Fatalf("expected fail-open dispatch within budget, got %s", elapsed) + } + + select { + case <-started: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for slow handler to start") + } + + close(release) + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for slow handler to finish after release") + } +} + +func TestModifyingHookPriority(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var mu sync.Mutex + var order []string + + // Register in reverse priority order to verify sorting. 
+ r.OnMessageSending("third", 30, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "third") + mu.Unlock() + return nil + }) + r.OnMessageSending("first", 10, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "first") + mu.Unlock() + return nil + }) + r.OnMessageSending("second", 20, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "second") + mu.Unlock() + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hi"}) + + if len(order) != 3 { + t.Fatalf("Expected 3 handlers, got %d", len(order)) + } + if order[0] != "first" || order[1] != "second" || order[2] != "third" { + t.Errorf("Expected [first second third], got %v", order) + } +} + +func TestModifyingHookCancel(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var secondCalled bool + + r.OnMessageSending("canceler", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked" + return nil + }) + r.OnMessageSending("after-cancel", 20, func(_ context.Context, _ *MessageSendingEvent) error { + secondCalled = true + return nil + }) + + event := &MessageSendingEvent{Content: "hi"} + r.TriggerMessageSending(ctx, event) + + if !event.Cancel { + t.Error("Expected Cancel to be true") + } + if secondCalled { + t.Error("Expected second handler NOT to be called after cancel") + } +} + +func TestBeforeToolCallModification(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnBeforeToolCall("modifier", 10, func(_ context.Context, e *BeforeToolCallEvent) error { + e.Args["injected"] = "value" + return nil + }) + + event := &BeforeToolCallEvent{ + ToolName: "search", + Args: map[string]any{"query": "test"}, + } + r.TriggerBeforeToolCall(ctx, event) + + if event.Args["injected"] != "value" { + t.Error("Expected injected arg to persist") + } + if event.Args["query"] != "test" { + 
t.Error("Expected original arg to remain") + } +} + +func TestMessageSendingFilter(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageSending("rewriter", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Content = "[filtered] " + e.Content + return nil + }) + + event := &MessageSendingEvent{Content: "hello world"} + r.TriggerMessageSending(ctx, event) + + if event.Content != "[filtered] hello world" { + t.Errorf("Expected '[filtered] hello world', got '%s'", event.Content) + } +} + +func TestZeroCostWhenEmpty(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // This is primarily a safety/smoke test — no panics, no allocations of note. + for range 100 { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{}) + r.TriggerLLMInput(ctx, &LLMInputEvent{}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{}) + r.TriggerSessionStart(ctx, &SessionEvent{}) + r.TriggerSessionEnd(ctx, &SessionEvent{}) + } +} + +func TestLLMInputOutput(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var inputCalled, outputCalled atomic.Bool + + r.OnLLMInput("input-hook", 0, func(_ context.Context, e *LLMInputEvent) error { + if e.Model != "gpt-4" { + t.Errorf("Expected model 'gpt-4', got '%s'", e.Model) + } + inputCalled.Store(true) + return nil + }) + + r.OnLLMOutput("output-hook", 0, func(_ context.Context, e *LLMOutputEvent) error { + if e.Content != "response" { + t.Errorf("Expected content 'response', got '%s'", e.Content) + } + outputCalled.Store(true) + return nil + }) + + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a1", Model: "gpt-4", Iteration: 1}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a1", Model: "gpt-4", Content: "response", Iteration: 1}) + + if !inputCalled.Load() { + t.Error("Expected LLM input hook to be 
called") + } + if !outputCalled.Load() { + t.Error("Expected LLM output hook to be called") + } +} + +func TestSessionStartEnd(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var startCalled, endCalled atomic.Bool + + r.OnSessionStart("start-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + startCalled.Store(true) + return nil + }) + + r.OnSessionEnd("end-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + endCalled.Store(true) + return nil + }) + + event := &SessionEvent{AgentID: "a1", SessionKey: "sess-1", Channel: "test", ChatID: "c1"} + r.TriggerSessionStart(ctx, event) + r.TriggerSessionEnd(ctx, event) + + if !startCalled.Load() { + t.Error("Expected session start hook to be called") + } + if !endCalled.Load() { + t.Error("Expected session end hook to be called") + } +} + +func TestConcurrentRegistrationAndTrigger(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var wg sync.WaitGroup + + // Goroutines registering hooks. + for i := range 10 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r.OnMessageReceived( + fmt.Sprintf("reg-hook-%d", idx), + idx, + func(_ context.Context, _ *MessageReceivedEvent) error { + return nil + }, + ) + }(i) + } + + // Goroutines triggering hooks concurrently. 
+ for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "race"}) + }() + } + + wg.Wait() +} + +func TestInsertSorted(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var order []int + + // Register with priorities: 50, 10, 30, 20, 40 + priorities := []int{50, 10, 30, 20, 40} + for _, p := range priorities { + r.OnBeforeToolCall(fmt.Sprintf("p-%d", p), p, func(_ context.Context, _ *BeforeToolCallEvent) error { + order = append(order, p) + return nil + }) + } + + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test", Args: map[string]any{}}) + + expected := []int{10, 20, 30, 40, 50} + if len(order) != len(expected) { + t.Fatalf("Expected %d handlers, got %d", len(expected), len(order)) + } + for i, v := range expected { + if order[i] != v { + t.Errorf("Position %d: expected priority %d, got %d", i, v, order[i]) + } + } +} + +func TestAfterToolCallExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called bool + var capturedName string + r.OnAfterToolCall("logger", 0, func(_ context.Context, event *AfterToolCallEvent) error { + called = true + capturedName = event.ToolName + return nil + }) + + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"cmd": "ls"}, + Channel: "telegram", + ChatID: "123", + }) + + if !called { + t.Error("Expected after_tool_call handler to be called") + } + if capturedName != "shell" { + t.Errorf("Expected ToolName 'shell', got '%s'", capturedName) + } +} + +func TestHandlerErrorsSwallowed(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Test void hooks: error in one handler doesn't prevent others from running + var secondCalled bool + r.OnMessageReceived("erroring", 10, func(_ context.Context, _ *MessageReceivedEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageReceived("observer", 20, func(_ context.Context, _ 
*MessageReceivedEvent) error { + secondCalled = true + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + if !secondCalled { + t.Error("Expected second void handler to run despite first handler's error") + } + + // Test modifying hooks: error doesn't stop chain (only Cancel does) + var modifySecondCalled bool + r.OnMessageSending("erroring", 10, func(_ context.Context, _ *MessageSendingEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageSending("modifier", 20, func(_ context.Context, _ *MessageSendingEvent) error { + modifySecondCalled = true + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "test"}) + if !modifySecondCalled { + t.Error("Expected second modifying handler to run despite first handler's error") + } +} + +func TestPanicRecovery(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Void hook: panic in one handler shouldn't crash, other handlers should still run + var safeHandlerCalled bool + r.OnLLMInput("panicker", 10, func(_ context.Context, _ *LLMInputEvent) error { + panic("boom") + }) + r.OnLLMInput("safe", 10, func(_ context.Context, _ *LLMInputEvent) error { + safeHandlerCalled = true + return nil + }) + + // Should not panic + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "test"}) + if !safeHandlerCalled { + t.Error("Expected safe handler to run despite panicking sibling") + } + + // Modifying hook: panic in handler shouldn't crash + r.OnBeforeToolCall("panicker", 10, func(_ context.Context, _ *BeforeToolCallEvent) error { + panic("boom") + }) + + // Should not panic + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test"}) +} diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go new file mode 100644 index 0000000000..4a0f6697d9 --- /dev/null +++ b/pkg/hooks/types.go @@ -0,0 +1,82 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT 
+// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// MessageReceivedEvent is fired when an inbound message is consumed from the bus. +type MessageReceivedEvent struct { + Channel string + SenderID string + ChatID string + Content string + Media []string + Metadata map[string]string +} + +// MessageSendingEvent is fired before an outbound message is published. +// Handlers can modify Content or set Cancel to block delivery. +type MessageSendingEvent struct { + Channel string + ChatID string + Content string // Modifiable + Cancel bool + CancelReason string +} + +// BeforeToolCallEvent is fired before a tool is executed. +// Handlers can modify Args, or set Cancel to block execution. +type BeforeToolCallEvent struct { + ToolName string + Args map[string]any // Modifiable; guaranteed non-nil when triggered via AgentLoop. + Channel string + ChatID string + Cancel bool + CancelReason string // Message returned to LLM when canceled +} + +// AfterToolCallEvent is fired after a tool completes execution. +type AfterToolCallEvent struct { + ToolName string + Args map[string]any + Channel string + ChatID string + Duration time.Duration + Result *tools.ToolResult +} + +// LLMInputEvent is fired before the LLM provider is called. +type LLMInputEvent struct { + AgentID string + Model string + Messages []providers.Message + Tools []providers.ToolDefinition + Iteration int +} + +// LLMOutputEvent is fired after the LLM provider responds. +type LLMOutputEvent struct { + AgentID string + Model string + Content string + ToolCalls []providers.ToolCall + Iteration int + Duration time.Duration +} + +// SessionEvent is fired at session start and end. 
+type SessionEvent struct { + AgentID string + SessionKey string + Channel string + ChatID string +} diff --git a/pkg/plugin/builtin/catalog.go b/pkg/plugin/builtin/catalog.go new file mode 100644 index 0000000000..807a4d0ecf --- /dev/null +++ b/pkg/plugin/builtin/catalog.go @@ -0,0 +1,31 @@ +package builtin + +import ( + "sort" + + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/plugin/demoplugin" +) + +// Factory creates one builtin plugin instance. +type Factory func() plugin.Plugin + +// Catalog returns compile-time builtin plugin factories by name. +func Catalog() map[string]Factory { + return map[string]Factory{ + "policy-demo": func() plugin.Plugin { + return demoplugin.NewPolicyDemoPlugin(demoplugin.PolicyDemoConfig{}) + }, + } +} + +// Names returns sorted builtin plugin names. +func Names() []string { + catalog := Catalog() + names := make([]string, 0, len(catalog)) + for name := range catalog { + names = append(names, name) + } + sort.Strings(names) + return names +} diff --git a/pkg/plugin/builtin/catalog_test.go b/pkg/plugin/builtin/catalog_test.go new file mode 100644 index 0000000000..8a48a3c997 --- /dev/null +++ b/pkg/plugin/builtin/catalog_test.go @@ -0,0 +1,32 @@ +package builtin + +import ( + "slices" + "testing" +) + +func TestCatalogContainsPolicyDemo(t *testing.T) { + catalog := Catalog() + factory, ok := catalog["policy-demo"] + if !ok { + t.Fatalf("Catalog() missing %q plugin", "policy-demo") + } + if factory == nil { + t.Fatalf("Catalog()[%q] factory is nil", "policy-demo") + } + if got := factory(); got == nil { + t.Fatalf("Catalog()[%q]() returned nil plugin", "policy-demo") + } +} + +func TestNamesSorted(t *testing.T) { + first := Names() + second := Names() + + if !slices.IsSorted(first) { + t.Fatalf("Names() is not sorted: %v", first) + } + if !slices.Equal(first, second) { + t.Fatalf("Names() is not deterministic across calls: %v vs %v", first, second) + } +} diff --git a/pkg/plugin/demoplugin/policy_demo.go 
b/pkg/plugin/demoplugin/policy_demo.go new file mode 100644 index 0000000000..6e89b1dae8 --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo.go @@ -0,0 +1,315 @@ +package demoplugin + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +// PolicyDemoConfig controls the demo plugin behavior. +type PolicyDemoConfig struct { + BlockedTools []string + RedactPrefixes []string + ChannelToolAllowlist map[string][]string + DenyOutboundPatterns []string + MaxToolTimeoutSecond int +} + +// PolicyDemoStats provides basic evidence that hook paths were executed. +type PolicyDemoStats struct { + BeforeToolCalls int + BlockedToolCalls int + MessageSends int + RedactedMessages int + BlockedMessages int + SessionStarts int + SessionEnds int + AfterToolCalls int + TotalToolDuration time.Duration +} + +// PolicyDemoPlugin demonstrates why plugins are needed: it enforces runtime policy +// at tool-call and outbound-message lifecycle points and collects audit metrics. 
+type PolicyDemoPlugin struct { + blockedTools map[string]struct{} + prefixes []string + channelAllowlist map[string]map[string]struct{} + denyPatterns []string + maxTimeout int + + mu sync.Mutex + stats PolicyDemoStats +} + +func NewPolicyDemoPlugin(cfg PolicyDemoConfig) *PolicyDemoPlugin { + blocked := make(map[string]struct{}, len(cfg.BlockedTools)) + for _, t := range cfg.BlockedTools { + t = normalizeLower(t) + if t == "" { + continue + } + blocked[t] = struct{}{} + } + + prefixes := make([]string, 0, len(cfg.RedactPrefixes)) + for _, p := range cfg.RedactPrefixes { + p = strings.TrimSpace(p) + if p == "" { + continue + } + prefixes = append(prefixes, p) + } + + allowlist := make(map[string]map[string]struct{}, len(cfg.ChannelToolAllowlist)) + for channel, tools := range cfg.ChannelToolAllowlist { + channel = normalizeLower(channel) + if channel == "" { + continue + } + toolSet := make(map[string]struct{}, len(tools)) + for _, t := range tools { + t = normalizeLower(t) + if t == "" { + continue + } + toolSet[t] = struct{}{} + } + allowlist[channel] = toolSet + } + + patterns := make([]string, 0, len(cfg.DenyOutboundPatterns)) + for _, p := range cfg.DenyOutboundPatterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + patterns = append(patterns, p) + } + + maxTimeout := cfg.MaxToolTimeoutSecond + if maxTimeout < 0 { + maxTimeout = 0 + } + + return &PolicyDemoPlugin{ + blockedTools: blocked, + prefixes: prefixes, + channelAllowlist: allowlist, + denyPatterns: patterns, + maxTimeout: maxTimeout, + } +} + +func (p *PolicyDemoPlugin) Name() string { + return "policy-demo" +} + +func (p *PolicyDemoPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p *PolicyDemoPlugin) Snapshot() PolicyDemoStats { + p.mu.Lock() + defer p.mu.Unlock() + return p.stats +} + +func (p *PolicyDemoPlugin) Register(r *hooks.HookRegistry) error { + r.OnBeforeToolCall("policy-demo-tool-policy", 100, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + 
tool := normalizeLower(e.ToolName) + p.incBeforeToolCalls() + + if _, blocked := p.blockedTools[tool]; blocked { + e.Cancel = true + e.CancelReason = "blocked by policy-demo plugin" + p.incBlockedToolCalls() + return nil + } + + channel := normalizeLower(e.Channel) + if allow, ok := p.channelAllowlist[channel]; ok { + if _, allowed := allow[tool]; !allowed { + e.Cancel = true + e.CancelReason = fmt.Sprintf("tool %q is not allowed on channel %q", e.ToolName, e.Channel) + p.incBlockedToolCalls() + return nil + } + } + + if p.maxTimeout > 0 { + clampArgNumber(e.Args, "timeout", p.maxTimeout) + clampArgNumber(e.Args, "timeout_seconds", p.maxTimeout) + } + return nil + }) + + r.OnMessageSending("policy-demo-redact-and-guard", 50, func(_ context.Context, e *hooks.MessageSendingEvent) error { + p.incMessageSends() + + for _, pattern := range p.denyPatterns { + if strings.Contains(e.Content, pattern) { + e.Cancel = true + e.CancelReason = "blocked by policy-demo outbound guard" + p.incBlockedMessages() + return nil + } + } + + content := e.Content + redacted := false + for _, prefix := range p.prefixes { + next := strings.ReplaceAll(content, prefix, "[redacted]-") + if next != content { + redacted = true + } + content = next + } + e.Content = content + if redacted { + p.incRedactedMessages() + } + return nil + }) + + r.OnSessionStart("policy-demo-session-start-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionStarts() + return nil + }) + + r.OnSessionEnd("policy-demo-session-end-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionEnds() + return nil + }) + + r.OnAfterToolCall("policy-demo-after-tool-audit", 0, func(_ context.Context, e *hooks.AfterToolCallEvent) error { + p.incAfterToolCall(e.Duration) + return nil + }) + + return nil +} + +func normalizeLower(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +func clampArgNumber(args map[string]any, key string, limit int) { + if args == nil || 
limit <= 0 { + return + } + v, ok := args[key] + if !ok { + return + } + n, ok := toInt(v) + if !ok { + return + } + if n > limit { + args[key] = limit + } +} + +func toInt(v any) (int, bool) { + maxInt := int(^uint(0) >> 1) + maxIntU64 := uint64(maxInt) + maxInt64 := int64(maxInt) + minInt64 := -maxInt64 - 1 + + switch n := v.(type) { + case int: + return n, true + case int8: + return int(n), true + case int16: + return int(n), true + case int32: + return int(n), true + case int64: + if n < minInt64 || n > maxInt64 { + return 0, false + } + return int(n), true + case uint: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint8: + return int(n), true + case uint16: + return int(n), true + case uint32: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint64: + if n > maxIntU64 { + return 0, false + } + return int(n), true + case float32: + // Truncation is intentional for timeout normalization. + return int(n), true + case float64: + // Truncation is intentional for timeout normalization. 
+ return int(n), true + default: + return 0, false + } +} + +func (p *PolicyDemoPlugin) incBeforeToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BeforeToolCalls++ +} + +func (p *PolicyDemoPlugin) incBlockedToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedToolCalls++ +} + +func (p *PolicyDemoPlugin) incMessageSends() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.MessageSends++ +} + +func (p *PolicyDemoPlugin) incRedactedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.RedactedMessages++ +} + +func (p *PolicyDemoPlugin) incBlockedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedMessages++ +} + +func (p *PolicyDemoPlugin) incSessionStarts() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionStarts++ +} + +func (p *PolicyDemoPlugin) incSessionEnds() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionEnds++ +} + +func (p *PolicyDemoPlugin) incAfterToolCall(d time.Duration) { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.AfterToolCalls++ + p.stats.TotalToolDuration += d +} diff --git a/pkg/plugin/demoplugin/policy_demo_test.go b/pkg/plugin/demoplugin/policy_demo_test.go new file mode 100644 index 0000000000..4d41084f1c --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo_test.go @@ -0,0 +1,189 @@ +package demoplugin + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +func TestPolicyDemoPluginBlocksConfiguredTool(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + BlockedTools: []string{"shell"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "cli"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if !e.Cancel { + t.Fatal("expected tool call to be canceled") + } + if e.CancelReason == "" { + t.Fatal("expected cancel 
reason") + } + + stats := p.Snapshot() + if stats.BeforeToolCalls != 1 || stats.BlockedToolCalls != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginRedactsOutboundContent(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + RedactPrefixes: []string{"sk-"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + + if e.Cancel { + t.Fatal("did not expect cancellation") + } + if e.Content != "token=[redacted]-abc123" { + t.Fatalf("unexpected redaction result: %q", e.Content) + } + + stats := p.Snapshot() + if stats.MessageSends != 1 || stats.RedactedMessages != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginChannelAllowlist(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + ChannelToolAllowlist: map[string][]string{ + "telegram": {"web_search"}, + }, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + blocked := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), blocked) + if !blocked.Cancel { + t.Fatal("expected tool to be blocked by channel allowlist") + } + + allowed := &hooks.BeforeToolCallEvent{ToolName: "web_search", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), allowed) + if allowed.Cancel { + t.Fatalf("did not expect allowlisted tool to be blocked: %s", allowed.CancelReason) + } +} + +func TestPolicyDemoPluginOutboundGuard(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + DenyOutboundPatterns: []string{"4111-1111-1111-1111", "@corp.internal"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register 
plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "card=4111-1111-1111-1111"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + if !e.Cancel { + t.Fatal("expected outbound message to be blocked") + } + if e.CancelReason == "" { + t.Fatal("expected block reason") + } + + stats := p.Snapshot() + if stats.BlockedMessages != 1 { + t.Fatalf("expected blocked message count to be 1, got %+v", stats) + } +} + +func TestPolicyDemoPluginNormalizesTimeoutArg(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{MaxToolTimeoutSecond: 30}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ + ToolName: "web_fetch", + Channel: "cli", + Args: map[string]any{ + "timeout": 120, + "timeout_seconds": 90.0, + }, + } + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if got, ok := e.Args["timeout"].(int); !ok || got != 30 { + t.Fatalf("expected timeout to be clamped to 30, got %#v", e.Args["timeout"]) + } + if got, ok := e.Args["timeout_seconds"].(int); !ok || got != 30 { + t.Fatalf("expected timeout_seconds to be clamped to 30, got %#v", e.Args["timeout_seconds"]) + } +} + +func TestPolicyDemoPluginAuditHooks(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + pm.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + pm.HookRegistry().TriggerAfterToolCall( + context.Background(), + &hooks.AfterToolCallEvent{ + ToolName: "web_search", + Duration: 45 * time.Millisecond, + }, + ) + pm.HookRegistry().TriggerSessionEnd(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + + stats := p.Snapshot() + if stats.SessionStarts != 1 || stats.SessionEnds != 1 { + t.Fatalf("unexpected session stats: %+v", stats) + } + if 
stats.AfterToolCalls != 1 || stats.TotalToolDuration != 45*time.Millisecond { + t.Fatalf("unexpected after_tool_call stats: %+v", stats) + } +} + +func TestPolicyDemoPluginNoConfigNoEffect(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + toolEvent := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), toolEvent) + if toolEvent.Cancel { + t.Fatal("did not expect cancellation with empty config") + } + + msgEvent := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), msgEvent) + if msgEvent.Content != "token=sk-abc123" { + t.Fatalf("did not expect content rewrite, got %q", msgEvent.Content) + } +} + +func TestToIntRejectsInt64OverflowOn32Bit(t *testing.T) { + if strconv.IntSize != 32 { + t.Skip("overflow scenario is specific to 32-bit int") + } + if _, ok := toInt(int64(1 << 40)); ok { + t.Fatal("expected overflow conversion to fail on 32-bit int") + } +} diff --git a/pkg/plugin/manager.go b/pkg/plugin/manager.go new file mode 100644 index 0000000000..c63f67e68a --- /dev/null +++ b/pkg/plugin/manager.go @@ -0,0 +1,275 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package plugin + +import ( + "errors" + "fmt" + "slices" + "sort" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +// APIVersion identifies the compile-time plugin contract version. +const APIVersion = "v1alpha1" + +// SelectionInput controls plugin enable/disable resolution. 
type SelectionInput struct {
	// DefaultEnabled is consulted by ResolveSelection only when Enabled
	// contributes no entries: every available plugin that is not listed in
	// Disabled is then enabled.
	DefaultEnabled bool
	// Enabled is an explicit allowlist. When it is non-empty, only the listed
	// plugins that are available and not disabled end up enabled. A listed
	// name that is not available makes ResolveSelection return an error.
	Enabled []string
	// Disabled lists plugins to exclude. It takes precedence over both Enabled
	// and DefaultEnabled; names that are not available only yield a warning.
	Disabled []string
}

// SelectionResult is the normalized output of plugin enable/disable resolution.
//
// All names are in NormalizePluginName form and every name slice is sorted.
type SelectionResult struct {
	// EnabledNames holds the available plugins that resolved to enabled.
	EnabledNames []string
	// DisabledNames holds every available plugin that is not in EnabledNames.
	DisabledNames []string
	// UnknownEnabled holds requested-enabled names that are not available.
	UnknownEnabled []string
	// UnknownDisabled holds requested-disabled names that are not available.
	UnknownDisabled []string
	// Warnings collects non-fatal findings (duplicate entries, unknown
	// disabled names) as human-readable messages.
	Warnings []string
}

// Plugin is the Phase-1 compile-time contract for PicoClaw extensions.
type Plugin interface {
	// Name returns the plugin identifier; Manager.Register rejects a name
	// that is blank after trimming.
	Name() string
	// APIVersion returns the contract version the plugin was built against;
	// Manager.Register rejects anything other than the package APIVersion.
	APIVersion() string
	// Register attaches the plugin's hooks to the given registry. It is
	// invoked by Manager.Register with the manager's shared registry.
	Register(registry *hooks.HookRegistry) error
}

// PluginInfo describes plugin metadata for introspection APIs.
type PluginInfo struct {
	Name       string `json:"name"`
	APIVersion string `json:"api_version"`
	// Status is a free-form state label; normalizePluginInfo defaults it to
	// "enabled" when a plugin does not supply one.
	Status string `json:"status"`
}

// PluginDescriptor optionally provides richer plugin metadata.
//
// When a registered Plugin also implements this interface, the non-blank
// fields of Info() override the values the manager would otherwise report.
type PluginDescriptor interface {
	Info() PluginInfo
}

// NormalizePluginName normalizes plugin names for deterministic matching.
//
// Surrounding whitespace is removed and the result is lower-cased, so
// " Alpha " and "ALPHA" both map to "alpha".
func NormalizePluginName(name string) string {
	return strings.ToLower(strings.TrimSpace(name))
}

// ResolveSelection resolves final enabled/disabled plugin names deterministically.
+func ResolveSelection(available []string, in SelectionInput) (SelectionResult, error) { + result := SelectionResult{} + + availableSet := make(map[string]struct{}, len(available)) + for _, name := range available { + normalized := NormalizePluginName(name) + if normalized == "" { + continue + } + availableSet[normalized] = struct{}{} + } + + enabledSet := make(map[string]struct{}, len(in.Enabled)) + for _, name := range in.Enabled { + normalized := NormalizePluginName(name) + if _, exists := enabledSet[normalized]; exists { + result.Warnings = append(result.Warnings, fmt.Sprintf("duplicate enabled plugin %q ignored", normalized)) + continue + } + enabledSet[normalized] = struct{}{} + } + + disabledSet := make(map[string]struct{}, len(in.Disabled)) + for _, name := range in.Disabled { + normalized := NormalizePluginName(name) + if _, exists := disabledSet[normalized]; exists { + result.Warnings = append(result.Warnings, fmt.Sprintf("duplicate disabled plugin %q ignored", normalized)) + continue + } + disabledSet[normalized] = struct{}{} + } + + for name := range enabledSet { + if _, ok := availableSet[name]; !ok { + result.UnknownEnabled = append(result.UnknownEnabled, name) + } + } + sort.Strings(result.UnknownEnabled) + + for name := range disabledSet { + if _, ok := availableSet[name]; !ok { + result.UnknownDisabled = append(result.UnknownDisabled, name) + } + } + sort.Strings(result.UnknownDisabled) + for _, name := range result.UnknownDisabled { + result.Warnings = append(result.Warnings, fmt.Sprintf("unknown disabled plugin %q ignored", name)) + } + + resolvedEnabled := make(map[string]struct{}, len(availableSet)) + if len(enabledSet) > 0 { + for name := range enabledSet { + if _, ok := availableSet[name]; !ok { + continue + } + if _, disabled := disabledSet[name]; disabled { + continue + } + resolvedEnabled[name] = struct{}{} + } + } else if in.DefaultEnabled { + for name := range availableSet { + if _, disabled := disabledSet[name]; disabled { + continue + 
} + resolvedEnabled[name] = struct{}{} + } + } + + for name := range resolvedEnabled { + result.EnabledNames = append(result.EnabledNames, name) + } + sort.Strings(result.EnabledNames) + + for name := range availableSet { + if _, enabled := resolvedEnabled[name]; enabled { + continue + } + result.DisabledNames = append(result.DisabledNames, name) + } + sort.Strings(result.DisabledNames) + + if len(result.UnknownEnabled) > 0 { + return result, fmt.Errorf("unknown enabled plugins: %s", strings.Join(result.UnknownEnabled, ", ")) + } + return result, nil +} + +// Manager owns a shared hook registry and loaded plugin metadata. +type Manager struct { + mu sync.RWMutex + registry *hooks.HookRegistry + names []string + plugins []Plugin + seen map[string]struct{} +} + +// NewManager creates an empty plugin manager with a fresh hook registry. +func NewManager() *Manager { + return &Manager{ + registry: hooks.NewHookRegistry(), + seen: make(map[string]struct{}), + } +} + +// HookRegistry returns the shared registry where plugins register hooks. +func (m *Manager) HookRegistry() *hooks.HookRegistry { + return m.registry +} + +// Names returns loaded plugin names in registration order. +func (m *Manager) Names() []string { + m.mu.RLock() + defer m.mu.RUnlock() + return slices.Clone(m.names) +} + +// DescribeAll returns plugin metadata in registration order. +func (m *Manager) DescribeAll() []PluginInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + infos := make([]PluginInfo, 0, len(m.plugins)) + for i, p := range m.plugins { + fallbackName := "" + if i < len(m.names) { + fallbackName = m.names[i] + } + infos = append(infos, normalizePluginInfo(p, fallbackName)) + } + return infos +} + +// DescribeEnabled returns metadata for currently enabled plugins. +func (m *Manager) DescribeEnabled() []PluginInfo { + return m.DescribeAll() +} + +// Register loads one plugin into the shared hook registry. 
+func (m *Manager) Register(p Plugin) error { + if p == nil { + return errors.New("plugin is nil") + } + name := strings.TrimSpace(p.Name()) + if name == "" { + return errors.New("plugin name is required") + } + if got := strings.TrimSpace(p.APIVersion()); got != APIVersion { + if got == "" { + got = "" + } + return fmt.Errorf( + "plugin %q api version mismatch: got %s, want %s", + name, + got, + APIVersion, + ) + } + + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.seen[name]; exists { + return fmt.Errorf("plugin %q already registered", name) + } + if err := p.Register(m.registry); err != nil { + return fmt.Errorf("register plugin %q: %w", name, err) + } + m.seen[name] = struct{}{} + m.names = append(m.names, name) + m.plugins = append(m.plugins, p) + return nil +} + +// RegisterAll loads plugins sequentially. +func (m *Manager) RegisterAll(plugins ...Plugin) error { + for _, p := range plugins { + if err := m.Register(p); err != nil { + return err + } + } + return nil +} + +func normalizePluginInfo(p Plugin, fallbackName string) PluginInfo { + info := PluginInfo{ + Name: strings.TrimSpace(fallbackName), + APIVersion: strings.TrimSpace(p.APIVersion()), + Status: "enabled", + } + if descriptor, ok := p.(PluginDescriptor); ok { + described := descriptor.Info() + if name := strings.TrimSpace(described.Name); name != "" { + info.Name = name + } + if version := strings.TrimSpace(described.APIVersion); version != "" { + info.APIVersion = version + } + if status := strings.TrimSpace(described.Status); status != "" { + info.Status = status + } + } + if info.Name == "" { + info.Name = strings.TrimSpace(p.Name()) + } + if info.APIVersion == "" { + info.APIVersion = APIVersion + } + if info.Status == "" { + info.Status = "enabled" + } + return info +} diff --git a/pkg/plugin/manager_test.go b/pkg/plugin/manager_test.go new file mode 100644 index 0000000000..c7423fd96d --- /dev/null +++ b/pkg/plugin/manager_test.go @@ -0,0 +1,374 @@ +package plugin + +import ( + 
"context" + "errors" + "slices" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +type testPlugin struct { + name string + apiVersion string + registerFn func(*hooks.HookRegistry) error +} + +func (p testPlugin) Name() string { + return p.name +} + +func (p testPlugin) Register(r *hooks.HookRegistry) error { + if p.registerFn != nil { + return p.registerFn(r) + } + return nil +} + +func (p testPlugin) APIVersion() string { + if p.apiVersion == "" { + return APIVersion + } + return p.apiVersion +} + +type descriptorTestPlugin struct { + testPlugin + info PluginInfo +} + +func (p descriptorTestPlugin) Info() PluginInfo { + return p.info +} + +func TestNewManager(t *testing.T) { + m := NewManager() + if m == nil { + t.Fatal("expected manager") + } + if m.HookRegistry() == nil { + t.Fatal("expected non-nil hook registry") + } + if len(m.Names()) != 0 { + t.Fatalf("expected empty names, got %v", m.Names()) + } +} + +func TestRegisterPluginAndTriggerHook(t *testing.T) { + m := NewManager() + called := false + p := testPlugin{ + name: "audit", + registerFn: func(r *hooks.HookRegistry) error { + r.OnSessionStart("audit-session", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + called = true + return nil + }) + return nil + }, + } + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + if got := m.Names(); len(got) != 1 || got[0] != "audit" { + t.Fatalf("unexpected names: %v", got) + } + + m.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{ + AgentID: "a1", + SessionKey: "s1", + }) + if !called { + t.Fatal("expected plugin hook to be called") + } +} + +func TestRegisterRejectsNilPlugin(t *testing.T) { + m := NewManager() + if err := m.Register(nil); err == nil { + t.Fatal("expected error for nil plugin") + } +} + +func TestRegisterRejectsEmptyName(t *testing.T) { + m := NewManager() + if err := m.Register(testPlugin{}); err == nil { + t.Fatal("expected error for empty name") + } 
+} + +func TestRegisterRejectsDuplicateName(t *testing.T) { + m := NewManager() + p := testPlugin{name: "dup"} + if err := m.Register(p); err != nil { + t.Fatalf("unexpected first register error: %v", err) + } + if err := m.Register(p); err == nil { + t.Fatal("expected duplicate name error") + } +} + +func TestRegisterPropagatesPluginError(t *testing.T) { + m := NewManager() + want := errors.New("register failed") + p := testPlugin{ + name: "bad", + registerFn: func(_ *hooks.HookRegistry) error { + return want + }, + } + err := m.Register(p) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, want) { + t.Fatalf("expected wrapped error %v, got %v", want, err) + } +} + +func TestRegisterRejectsPluginVersionMismatch(t *testing.T) { + m := NewManager() + p := testPlugin{ + name: "old-plugin", + apiVersion: "v0", + } + err := m.Register(p) + if err == nil { + t.Fatal("expected version mismatch error") + } +} + +func TestDescribeAll_UsesDescriptorWhenImplemented(t *testing.T) { + m := NewManager() + p := descriptorTestPlugin{ + testPlugin: testPlugin{name: "descriptor"}, + info: PluginInfo{ + Name: " descriptor-visible ", + APIVersion: " custom-v1 ", + Status: " active ", + }, + } + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + + got := m.DescribeAll() + want := []PluginInfo{ + { + Name: "descriptor-visible", + APIVersion: "custom-v1", + Status: "active", + }, + } + if !slices.Equal(got, want) { + t.Fatalf("DescribeAll() mismatch: got %v, want %v", got, want) + } +} + +func TestDescribeAll_FallsBackForPlainPlugin(t *testing.T) { + m := NewManager() + p := testPlugin{name: "plain"} + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + + got := m.DescribeAll() + want := []PluginInfo{ + { + Name: "plain", + APIVersion: APIVersion, + Status: "enabled", + }, + } + if !slices.Equal(got, want) { + t.Fatalf("DescribeAll() mismatch: got %v, want %v", got, want) + } +} + +func 
TestDescribeEnabled_MatchesDescribeAllForNow(t *testing.T) { + m := NewManager() + plain := testPlugin{name: "plain"} + described := descriptorTestPlugin{ + testPlugin: testPlugin{name: "described"}, + info: PluginInfo{ + Name: " described-visible ", + }, + } + + if err := m.RegisterAll(plain, described); err != nil { + t.Fatalf("RegisterAll() error = %v", err) + } + + all := m.DescribeAll() + enabled := m.DescribeEnabled() + if !slices.Equal(enabled, all) { + t.Fatalf("DescribeEnabled() mismatch: got %v, want %v", enabled, all) + } + + wantAll := []PluginInfo{ + { + Name: "plain", + APIVersion: APIVersion, + Status: "enabled", + }, + { + Name: "described-visible", + APIVersion: APIVersion, + Status: "enabled", + }, + } + if !slices.Equal(all, wantAll) { + t.Fatalf("DescribeAll() order/content mismatch: got %v, want %v", all, wantAll) + } +} + +func TestResolveSelection_DefaultEnabled(t *testing.T) { + result, err := ResolveSelection( + []string{"beta", "alpha", "gamma"}, + SelectionInput{ + DefaultEnabled: true, + Disabled: []string{"beta"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"alpha", "gamma"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"beta"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func TestResolveSelection_EnabledListOnly(t *testing.T) { + result, err := ResolveSelection( + []string{"a", "b", "c"}, + SelectionInput{ + DefaultEnabled: true, + Enabled: []string{"c", "a"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a", "c"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"b"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func 
TestResolveSelection_DisabledWinsOverlap(t *testing.T) { + result, err := ResolveSelection( + []string{"a", "b", "c"}, + SelectionInput{ + Enabled: []string{"a", "b"}, + Disabled: []string{"b"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"b", "c"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } +} + +func TestResolveSelection_UnknownEnabledFails(t *testing.T) { + result, err := ResolveSelection( + []string{"a"}, + SelectionInput{ + Enabled: []string{"missing"}, + }, + ) + if err == nil { + t.Fatal("expected error for unknown enabled plugin") + } + if !strings.Contains(err.Error(), "missing") { + t.Fatalf("expected error to mention unknown plugin, got %v", err) + } + if !slices.Equal(result.UnknownEnabled, []string{"missing"}) { + t.Fatalf("UnknownEnabled mismatch: got %v", result.UnknownEnabled) + } +} + +func TestResolveSelection_UnknownDisabledWarns(t *testing.T) { + result, err := ResolveSelection( + []string{"a"}, + SelectionInput{ + DefaultEnabled: true, + Disabled: []string{"missing"}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"a"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if len(result.DisabledNames) != 0 { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } + if !slices.Equal(result.UnknownDisabled, []string{"missing"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", result.UnknownDisabled) + } + if !hasWarningSubstring(result.Warnings, `unknown disabled plugin "missing" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", result.Warnings) + } +} + +func TestResolveSelection_NormalizationAndDedupe(t *testing.T) { + result, err := ResolveSelection( + []string{" Alpha ", 
"beta", "gamma"}, + SelectionInput{ + Enabled: []string{"ALPHA", " alpha ", "BETA", "beta"}, + Disabled: []string{" beta", "BETA", "missing", " MISSING "}, + }, + ) + if err != nil { + t.Fatalf("ResolveSelection() error = %v", err) + } + if !slices.Equal(result.EnabledNames, []string{"alpha"}) { + t.Fatalf("EnabledNames mismatch: got %v", result.EnabledNames) + } + if !slices.Equal(result.DisabledNames, []string{"beta", "gamma"}) { + t.Fatalf("DisabledNames mismatch: got %v", result.DisabledNames) + } + if !slices.Equal(result.UnknownDisabled, []string{"missing"}) { + t.Fatalf("UnknownDisabled mismatch: got %v", result.UnknownDisabled) + } + if !hasWarningSubstring(result.Warnings, `duplicate enabled plugin "alpha" ignored`) { + t.Fatalf("expected duplicate enabled warning for alpha, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate enabled plugin "beta" ignored`) { + t.Fatalf("expected duplicate enabled warning for beta, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate disabled plugin "beta" ignored`) { + t.Fatalf("expected duplicate disabled warning for beta, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `duplicate disabled plugin "missing" ignored`) { + t.Fatalf("expected duplicate disabled warning for missing, got %v", result.Warnings) + } + if !hasWarningSubstring(result.Warnings, `unknown disabled plugin "missing" ignored`) { + t.Fatalf("expected unknown disabled warning, got %v", result.Warnings) + } +} + +func hasWarningSubstring(warnings []string, sub string) bool { + for _, warning := range warnings { + if strings.Contains(warning, sub) { + return true + } + } + return false +}