From bb7202462a7701b306fdec924fc3acc17cb10724 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Fri, 13 Mar 2026 13:56:50 +0100 Subject: [PATCH 1/4] feat(agent): steering --- docs/design/steering-spec.md | 257 +++++++++++++ docs/steering.md | 124 ++++++ pkg/agent/loop.go | 238 +++++++----- pkg/agent/steering.go | 188 +++++++++ pkg/agent/steering_test.go | 713 +++++++++++++++++++++++++++++++++++ pkg/config/config.go | 1 + pkg/config/defaults.go | 1 + 7 files changed, 1424 insertions(+), 98 deletions(-) create mode 100644 docs/design/steering-spec.md create mode 100644 docs/steering.md create mode 100644 pkg/agent/steering.go create mode 100644 pkg/agent/steering_test.go diff --git a/docs/design/steering-spec.md b/docs/design/steering-spec.md new file mode 100644 index 0000000000..7cd8dc452c --- /dev/null +++ b/docs/design/steering-spec.md @@ -0,0 +1,257 @@ +# Steering — Implementation Specification + +## Problem + +When the agent is running (executing a chain of tool calls), the user has no way to redirect it. They must wait for the full cycle to complete before sending a new message. This creates a poor experience when the agent takes a wrong direction — the user watches it waste time on tools that are no longer relevant. + +## Solution + +Steering introduces a **message queue** that external callers can push into at any time. The agent loop polls this queue at well-defined checkpoints. When a steering message is found, the agent: + +1. Stops executing further tools in the current batch +2. Injects the user's message into the conversation context +3. Calls the LLM again with the updated context + +The user's intent reaches the model **as soon as the current tool finishes**, not after the entire turn completes. + +## Architecture Overview + +```mermaid +graph TD + subgraph External Callers + CH[Channel Handler] + API[HTTP API] + WS[WebSocket] + end + + subgraph AgentLoop + SQ[steeringQueue] + RLI[runLLMIteration] + TE[Tool Execution Loop] + LLM[LLM Call] + end + + CH -->|Steer| SQ + API -->|Steer| SQ + WS -->|Steer| SQ + + RLI -->|1. initial poll| SQ + TE -->|2. poll after each tool| SQ + TE -->|3. poll after last tool| SQ + + SQ -->|pendingMessages| RLI + RLI -->|inject into context| LLM +``` + +## Data Structures + +### steeringQueue + +A thread-safe FIFO queue, private to the `agent` package. + +| Field | Type | Description | +|-------|------|-------------| +| `mu` | `sync.Mutex` | Protects all access to `queue` and `mode` | +| `queue` | `[]providers.Message` | Pending steering messages | +| `mode` | `SteeringMode` | Dequeue strategy | + +**Methods:** + +| Method | Description | +|--------|-------------| +| `push(msg)` | Appends a message to the queue | +| `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty | +| `len() int` | Returns the current queue length | +| `setMode(mode)` | Updates the dequeue strategy | +| `getMode() SteeringMode` | Returns the current mode | + +### SteeringMode + +| Value | Constant | Behavior | +|-------|----------|----------| +| `"one-at-a-time"` | `SteeringOneAtATime` | `dequeue()` returns only the **first** message. Remaining messages stay in the queue for subsequent polls. | +| `"all"` | `SteeringAll` | `dequeue()` drains the **entire** queue and returns all messages at once. | + +Default: `"one-at-a-time"`. + +### processOptions extension + +A new field was added to `processOptions`: + +| Field | Type | Description | +|-------|------|-------------| +| `SkipInitialSteeringPoll` | `bool` | When `true`, the initial steering poll at loop start is skipped. Used by `Continue()` to avoid double-dequeuing. | + +## Public API on AgentLoop + +| Method | Signature | Description | +|--------|-----------|-------------| +| `Steer` | `Steer(msg providers.Message)` | Enqueues a steering message. Thread-safe, can be called from any goroutine. | +| `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. | +| `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. | +| `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. Returns `""` if queue is empty. | + +## Integration into the Agent Loop + +### Where steering is wired + +The steering queue lives as a field on `AgentLoop`: + +``` +AgentLoop + ├── bus + ├── cfg + ├── registry + ├── steering *steeringQueue ← new + ├── ... +``` + +It is initialized in `NewAgentLoop` from `cfg.Agents.Defaults.SteeringMode`. + +### Detailed flow through runLLMIteration + +```mermaid +sequenceDiagram + participant User + participant AgentLoop + participant runLLMIteration + participant ToolExecution + participant LLM + + User->>AgentLoop: Steer(message) + Note over AgentLoop: steeringQueue.push(message) + + Note over runLLMIteration: ── iteration starts ── + + runLLMIteration->>AgentLoop: dequeueSteeringMessages()
[initial poll] + AgentLoop-->>runLLMIteration: [] (empty, or messages) + + alt pendingMessages not empty + runLLMIteration->>runLLMIteration: inject into messages[]
save to session + end + + runLLMIteration->>LLM: Chat(messages, tools) + LLM-->>runLLMIteration: response with toolCalls[0..N] + + loop for each tool call (sequential) + alt i > 0 + ToolExecution->>AgentLoop: dequeueSteeringMessages() + AgentLoop-->>ToolExecution: steeringMessages + + alt steering found + Note over ToolExecution: Mark tool[i..N] as
"Skipped due to queued user message." + ToolExecution-->>runLLMIteration: steeringAfterTools = steeringMessages + Note over ToolExecution: break out of tool loop + end + end + + ToolExecution->>ToolExecution: execute tool[i] + ToolExecution->>ToolExecution: process result,
append to messages[] + + alt last tool (i == N-1) + ToolExecution->>AgentLoop: dequeueSteeringMessages() + AgentLoop-->>ToolExecution: steeringMessages (may be empty) + Note over ToolExecution: steeringAfterTools = steeringMessages + end + end + + alt steeringAfterTools not empty + ToolExecution-->>runLLMIteration: pendingMessages = steeringAfterTools + Note over runLLMIteration: next iteration will inject
these before calling LLM + end + + Note over runLLMIteration: ── loop back to iteration start ── +``` + +### Polling checkpoints + +| # | Location | When | Purpose | +|---|----------|------|---------| +| 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context | +| 2 | Between tool calls, before tool `[i]` where `i > 0` | After each tool finishes | Interrupt mid-batch if the user sent a steering message | +| 3 | After the last tool in the batch | After tool `[N-1]` finishes | Catch messages that arrived during the last tool's execution | + +### What happens to skipped tools + +When steering interrupts a tool batch at index `i`, all tools from `i` to `N-1` are **not executed**. Instead, a tool result message is generated for each: + +```json +{ + "role": "tool", + "content": "Skipped due to queued user message.", + "tool_call_id": "" +} +``` + +These results are: +- Appended to the conversation `messages[]` +- Saved to the session via `AddFullMessage` + +This ensures the LLM knows which of its requested actions were not performed. + +### Loop condition change + +The iteration loop condition was changed from: + +```go +for iteration < agent.MaxIterations +``` + +to: + +```go +for iteration < agent.MaxIterations || len(pendingMessages) > 0 +``` + +This allows **one extra iteration** when steering arrives right at the max iteration boundary, ensuring the steering message is always processed. + +### Tool execution: parallel → sequential + +**Before steering:** all tool calls in a batch were executed in parallel using `sync.WaitGroup`. + +**After steering:** tool calls execute **sequentially**. This is required because steering must be polled between individual tool completions. A parallel execution model would not allow interrupting mid-batch. + +> **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost. + +## The Continue() method + +`Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime. + +```mermaid +flowchart TD + A[Continue called] --> B{dequeueSteeringMessages} + B -->|empty| C["return ('', nil)"] + B -->|messages found| D[Combine message contents] + D --> E["runAgentLoop with
SkipInitialSteeringPoll: true"] + E --> F[Return response] +``` + +**Why `SkipInitialSteeringPoll: true`?** Because `Continue` already dequeued the messages itself. Without this flag, `runLLMIteration` would poll again at the start and find nothing (the queue is already empty), or worse, double-process if new messages arrived in the meantime. + +## Configuration + +```json +{ + "agents": { + "defaults": { + "steering_mode": "one-at-a-time" + } + } +} +``` + +| Field | Type | Default | Env var | +|-------|------|---------|---------| +| `steering_mode` | `string` | `"one-at-a-time"` | `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` | + + +## Design decisions and trade-offs + +| Decision | Rationale | +|----------|-----------| +| Sequential tool execution | Required for per-tool steering polls. Parallel execution cannot be interrupted mid-batch. | +| Polling-based (not channel/signal) | Keeps the implementation simple. No need for `select` or signal channels. The polling cost is negligible (mutex lock + slice length check). | +| `one-at-a-time` as default | Gives the model a chance to react to each steering message individually. More predictable behavior than dumping all messages at once. | +| Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. | +| `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. | +| Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. | diff --git a/docs/steering.md b/docs/steering.md new file mode 100644 index 0000000000..6f4cbd27b4 --- /dev/null +++ b/docs/steering.md @@ -0,0 +1,124 @@ +# Steering + +Steering allows injecting messages into an already-running agent loop, interrupting it between tool calls without waiting for the entire cycle to complete. + +## How it works + +When the agent is executing a sequence of tool calls (e.g. the model requested 3 tools in a single turn), steering checks the queue **after each tool** completes. If it finds queued messages: + +1. The remaining tools are **skipped** and receive `"Skipped due to queued user message."` as their result +2. The steering messages are **injected into the conversation context** +3. The model is called again with the updated context, including the user's steering message + +``` +User ──► Steer("change approach") + │ +Agent Loop ▼ + ├─ tool[0] ✔ (executed) + ├─ [polling] → steering found! + ├─ tool[1] ✘ (skipped) + ├─ tool[2] ✘ (skipped) + └─ new LLM turn with steering message +``` + +## Configuration + +In `config.json`, under `agents.defaults`: + +```json +{ + "agents": { + "defaults": { + "steering_mode": "one-at-a-time" + } + } +} +``` + +### Modes + +| Value | Behavior | +|-------|----------| +| `"one-at-a-time"` | **(default)** Dequeues only one message per polling cycle. If there are 3 messages in the queue, they are processed one at a time across 3 successive iterations. | +| `"all"` | Drains the entire queue in a single poll. All pending messages are injected into the context together. | + +The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as an alternative. + +## Go API + +### Steer — Send a steering message + +```go +agentLoop.Steer(providers.Message{ + Role: "user", + Content: "change direction, focus on X instead", +}) +``` + +The message is enqueued in a thread-safe manner. It will be picked up at the next polling point (after the current tool finishes). + +### SteeringMode / SetSteeringMode + +```go +// Read the current mode +mode := agentLoop.SteeringMode() // SteeringOneAtATime | SteeringAll + +// Change it at runtime +agentLoop.SetSteeringMode(agent.SteeringAll) +``` + +### Continue — Resume an idle agent + +When the agent is idle (it has finished processing and its last message was from the assistant), `Continue` checks if there are steering messages in the queue and uses them to start a new cycle: + +```go +response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID) +if response == "" { + // No steering messages in queue, the agent stays idle +} +``` + +`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input). + +## Polling points in the loop + +Steering is checked at **three points** in the agent cycle: + +1. **At loop start** — before the first LLM call, to catch messages enqueued during setup +2. **After each tool** — between tool calls within the same batch +3. **After the last tool** — to catch messages that arrived while the last tool was executing + +## Skipped tool behavior + +When steering interrupts a batch of tool calls, the tools that were not yet executed receive a `tool` result with: + +``` +Content: "Skipped due to queued user message." +``` + +This is saved to the session and sent to the model, so it is aware that some requested actions were not performed. + +## Full flow example + +``` +1. User: "search for info on X, write a file, and send me a message" + +2. LLM responds with 3 tool calls: [web_search, write_file, message] + +3. web_search is executed → result saved + +4. [polling] → User called Steer("no, search for Y instead") + +5. write_file is skipped → "Skipped due to queued user message." + message is skipped → "Skipped due to queued user message." + +6. Message "search for Y instead" injected into context + +7. LLM receives the full updated context and responds accordingly +``` + +## Notes + +- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue. +- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually. +- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index dfa339dee1..f3ed3ca6a8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -48,6 +48,7 @@ type AgentLoop struct { transcriber voice.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime + steering *steeringQueue mu sync.RWMutex // Track active requests for safe provider cleanup activeRequests sync.WaitGroup @@ -55,15 +56,16 @@ type AgentLoop struct { // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - Media []string // media:// refs from inbound message - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) + SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) } const ( @@ -105,6 +107,7 @@ func NewAgentLoop( summarizing: sync.Map{}, fallback: fallbackChain, cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), + steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)), } return al @@ -999,6 +1002,16 @@ func (al *AgentLoop) runLLMIteration( ) (string, int, error) { iteration := 0 var finalContent string + var pendingMessages []providers.Message + + // Poll for steering messages at loop start (in case the user typed while + // the agent was setting up), unless the caller already provided initial + // steering messages (e.g. Continue). + if !opts.SkipInitialSteeringPoll { + if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 { + pendingMessages = msgs + } + } // Determine effective model tier for this conversation turn. // selectCandidates evaluates routing once and the decision is sticky for @@ -1006,9 +1019,25 @@ func (al *AgentLoop) runLLMIteration( // tool chain doesn't switch models mid-way through. activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) - for iteration < agent.MaxIterations { + for iteration < agent.MaxIterations || len(pendingMessages) > 0 { iteration++ + // Inject pending steering messages into the conversation context + // before the next LLM call. + if len(pendingMessages) > 0 { + for _, pm := range pendingMessages { + messages = append(messages, pm) + agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content) + logger.InfoCF("agent", "Injected steering message into context", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "content_len": len(pm.Content), + }) + } + pendingMessages = nil + } + logger.DebugCF("agent", "LLM iteration", map[string]any{ "agent_id": agent.ID, @@ -1251,107 +1280,110 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - // Execute tool calls in parallel - type indexedAgentResult struct { - result *tools.ToolResult - tc providers.ToolCall - } - - agentResults := make([]indexedAgentResult, len(normalizedToolCalls)) - var wg sync.WaitGroup + // Execute tool calls sequentially. After each tool completes, check + // for steering messages. If any are found, skip remaining tools. + var steeringAfterTools []providers.Message for i, tc := range normalizedToolCalls { - agentResults[i].tc = tc - - wg.Add(1) - go func(idx int, tc providers.ToolCall) { - defer wg.Done() - - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "agent_id": agent.ID, - "tool": tc.Name, - "iteration": iteration, - }) - - // Create async callback for tools that implement AsyncExecutor. - // When the background work completes, this publishes the result - // as an inbound system message so processSystemMessage routes it - // back to the user via the normal agent loop. - asyncCallback := func(_ context.Context, result *tools.ToolResult) { - // Send ForUser content directly to the user (immediate feedback), - // mirroring the synchronous tool execution path. - if !result.Silent && result.ForUser != "" { - outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer outCancel() - _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, - Content: result.ForUser, + // Check for steering before executing (except for the first tool) + if i > 0 { + if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + steeringAfterTools = steerMsgs + logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", + map[string]any{ + "agent_id": agent.ID, + "skipped_from": i, + "total_tools": len(normalizedToolCalls), + "steering_count": len(steerMsgs), }) - } - // Determine content for the agent loop (ForLLM or error). - content := result.ForLLM - if content == "" && result.Err != nil { - content = result.Err.Error() - } - if content == "" { - return + // Mark remaining tool calls as skipped + for j := i; j < len(normalizedToolCalls); j++ { + skippedTC := normalizedToolCalls[j] + toolResultMsg := providers.Message{ + Role: "tool", + Content: "Skipped due to queued user message.", + ToolCallID: skippedTC.ID, + } + messages = append(messages, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } + break + } + } - logger.InfoCF("agent", "Async tool completed, publishing result", - map[string]any{ - "tool": tc.Name, - "content_len": len(content), - "channel": opts.Channel, - }) + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "agent_id": agent.ID, + "tool": tc.Name, + "iteration": iteration, + }) - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ - Channel: "system", - SenderID: fmt.Sprintf("async:%s", tc.Name), - ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), - Content: content, + // Create async callback for tools that implement AsyncExecutor. + asyncCallback := func(_ context.Context, result *tools.ToolResult) { + if !result.Silent && result.ForUser != "" { + outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer outCancel() + _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: result.ForUser, }) } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) - agentResults[idx].result = toolResult - }(i, tc) - } - wg.Wait() + content := result.ForLLM + if content == "" && result.Err != nil { + content = result.Err.Error() + } + if content == "" { + return + } + + logger.InfoCF("agent", "Async tool completed, publishing result", + map[string]any{ + "tool": tc.Name, + "content_len": len(content), + "channel": opts.Channel, + }) - // Process results in original order (send to user, save to session) - for _, r := range agentResults { - // Send ForUser content to user immediately if not Silent - if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ + Channel: "system", + SenderID: fmt.Sprintf("async:%s", tc.Name), + ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), + Content: content, + }) + } + + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + + // Process tool result + if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: r.result.ForUser, + Content: toolResult.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": r.tc.Name, - "content_len": len(r.result.ForUser), + "tool": tc.Name, + "content_len": len(toolResult.ForUser), }) } - // If tool returned media refs, publish them as outbound media - if len(r.result.Media) > 0 { - parts := make([]bus.MediaPart, 0, len(r.result.Media)) - for _, ref := range r.result.Media { + if len(toolResult.Media) > 0 { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { part := bus.MediaPart{Ref: ref} if al.mediaStore != nil { if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { @@ -1369,21 +1401,31 @@ func (al *AgentLoop) runLLMIteration( }) } - // Determine content for LLM based on tool result - contentForLLM := r.result.ForLLM - if contentForLLM == "" && r.result.Err != nil { - contentForLLM = r.result.Err.Error() + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: r.tc.ID, + ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) - - // Save tool result message to session agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + + // After the last tool, also check for steering messages. + if i == len(normalizedToolCalls)-1 { + if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + steeringAfterTools = steerMsgs + } + } + } + + // If steering messages were captured during tool execution, they + // become pendingMessages for the next iteration of the inner loop. + if len(steeringAfterTools) > 0 { + pendingMessages = steeringAfterTools } // Tick down TTL of discovered tools after processing tool results. diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go new file mode 100644 index 0000000000..8c7c79c160 --- /dev/null +++ b/pkg/agent/steering.go @@ -0,0 +1,188 @@ +package agent + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// SteeringMode controls how queued steering messages are dequeued. +type SteeringMode string + +const ( + // SteeringOneAtATime dequeues only the first queued message per poll. + SteeringOneAtATime SteeringMode = "one-at-a-time" + // SteeringAll drains the entire queue in a single poll. + SteeringAll SteeringMode = "all" + // MaxQueueSize number of possible messages in the Steering Queue + MaxQueueSize = 10 +) + +// parseSteeringMode normalizes a config string into a SteeringMode. +func parseSteeringMode(s string) SteeringMode { + switch s { + case "all": + return SteeringAll + default: + return SteeringOneAtATime + } +} + +// steeringQueue is a thread-safe queue of user messages that can be injected +// into a running agent loop to interrupt it between tool calls. +type steeringQueue struct { + mu sync.Mutex + queue []providers.Message + mode SteeringMode +} + +func newSteeringQueue(mode SteeringMode) *steeringQueue { + return &steeringQueue{ + mode: mode, + } +} + +// push enqueues a steering message. +func (sq *steeringQueue) push(msg providers.Message) error { + sq.mu.Lock() + defer sq.mu.Unlock() + if len(sq.queue) >= MaxQueueSize { + return fmt.Errorf("steering queue is full") + } + sq.queue = append(sq.queue, msg) + return nil +} + +// dequeue removes and returns pending steering messages according to the +// configured mode. Returns nil when the queue is empty. +func (sq *steeringQueue) dequeue() []providers.Message { + sq.mu.Lock() + defer sq.mu.Unlock() + + if len(sq.queue) == 0 { + return nil + } + + switch sq.mode { + case SteeringAll: + msgs := sq.queue + sq.queue = nil + return msgs + default: // one-at-a-time + msg := sq.queue[0] + sq.queue[0] = providers.Message{} // Clear reference for GC + sq.queue = sq.queue[1:] + return []providers.Message{msg} + } +} + +// len returns the number of queued messages. +func (sq *steeringQueue) len() int { + sq.mu.Lock() + defer sq.mu.Unlock() + return len(sq.queue) +} + +// setMode updates the steering mode. +func (sq *steeringQueue) setMode(mode SteeringMode) { + sq.mu.Lock() + defer sq.mu.Unlock() + sq.mode = mode +} + +// getMode returns the current steering mode. +func (sq *steeringQueue) getMode() SteeringMode { + sq.mu.Lock() + defer sq.mu.Unlock() + return sq.mode +} + +// --- AgentLoop steering API --- + +// Steer enqueues a user message to be injected into the currently running +// agent loop. The message will be picked up after the current tool finishes +// executing, causing any remaining tool calls in the batch to be skipped. +func (al *AgentLoop) Steer(msg providers.Message) error { + if al.steering == nil { + return fmt.Errorf("steering queue is not initialized") + } + if err := al.steering.push(msg); err != nil { + logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{ + "error": err.Error(), + "role": msg.Role, + }) + return err + } + logger.DebugCF("agent", "Steering message enqueued", map[string]any{ + "role": msg.Role, + "content_len": len(msg.Content), + "queue_len": al.steering.len(), + }) + + return nil +} + +// SteeringMode returns the current steering mode. +func (al *AgentLoop) SteeringMode() SteeringMode { + if al.steering == nil { + return SteeringOneAtATime + } + return al.steering.getMode() +} + +// SetSteeringMode updates the steering mode. +func (al *AgentLoop) SetSteeringMode(mode SteeringMode) { + if al.steering == nil { + return + } + al.steering.setMode(mode) +} + +// dequeueSteeringMessages is the internal method called by the agent loop +// to poll for steering messages. Returns nil when no messages are pending. +func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { + if al.steering == nil { + return nil + } + return al.steering.dequeue() +} + +// Continue resumes an idle agent by dequeuing any pending steering messages +// and running them through the agent loop. This is used when the agent's last +// message was from the assistant (i.e., it has stopped processing) and the +// user has since enqueued steering messages. +// +// If no steering messages are pending, it returns an empty string. +func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) { + steeringMsgs := al.dequeueSteeringMessages() + if len(steeringMsgs) == 0 { + return "", nil + } + + agent := al.GetRegistry().GetDefaultAgent() + if agent == nil { + return "", fmt.Errorf("no default agent available") + } + + // Build a combined user message from the steering messages. + var contents []string + for _, msg := range steeringMsgs { + contents = append(contents, msg.Content) + } + combinedContent := strings.Join(contents, "\n") + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: channel, + ChatID: chatID, + UserMessage: combinedContent, + DefaultResponse: defaultResponse, + EnableSummary: true, + SendResponse: false, + SkipInitialSteeringPoll: true, + }) +} diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go new file mode 100644 index 0000000000..173814d066 --- /dev/null +++ b/pkg/agent/steering_test.go @@ -0,0 +1,713 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// --- steeringQueue unit tests --- + +func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + sq.push(providers.Message{Role: "user", Content: "msg1"}) + sq.push(providers.Message{Role: "user", Content: "msg2"}) + sq.push(providers.Message{Role: "user", Content: "msg3"}) + + if sq.len() != 3 { + t.Fatalf("expected 3 messages, got %d", sq.len()) + } + + msgs := sq.dequeue() + if len(msgs) != 1 { + t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs)) + } + if msgs[0].Content != "msg1" { + t.Fatalf("expected 'msg1', got %q", msgs[0].Content) + } + if sq.len() != 2 { + t.Fatalf("expected 2 remaining, got %d", sq.len()) + } + + msgs = sq.dequeue() + if len(msgs) != 1 || msgs[0].Content != "msg2" { + t.Fatalf("expected 'msg2', got %v", msgs) + } + + msgs = sq.dequeue() + if len(msgs) != 1 || msgs[0].Content != "msg3" { + t.Fatalf("expected 'msg3', got %v", msgs) + } + + msgs = sq.dequeue() + if msgs != nil { + t.Fatalf("expected nil from empty queue, got %v", msgs) + } +} + +func TestSteeringQueue_PushDequeue_All(t *testing.T) { + sq := newSteeringQueue(SteeringAll) + + sq.push(providers.Message{Role: "user", Content: "msg1"}) + sq.push(providers.Message{Role: "user", Content: "msg2"}) + sq.push(providers.Message{Role: "user", Content: "msg3"}) + + msgs := sq.dequeue() + if len(msgs) != 3 { + t.Fatalf("expected 3 messages in all mode, got %d", len(msgs)) + } + if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" { + t.Fatalf("unexpected messages: %v", msgs) + } + + if sq.len() != 0 { + t.Fatalf("expected 0 remaining, got %d", sq.len()) + } + + msgs = sq.dequeue() + if msgs != nil { + t.Fatalf("expected nil from empty queue, got %v", msgs) + } +} + +func TestSteeringQueue_EmptyDequeue(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + if msgs := sq.dequeue(); msgs != nil { + t.Fatalf("expected nil, got %v", msgs) + } +} + +func TestSteeringQueue_SetMode(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + if sq.getMode() != SteeringOneAtATime { + t.Fatalf("expected one-at-a-time, got %v", sq.getMode()) + } + + sq.setMode(SteeringAll) + if sq.getMode() != SteeringAll { + t.Fatalf("expected all, got %v", sq.getMode()) + } + + // Push two messages and verify all-mode drains them + sq.push(providers.Message{Role: "user", Content: "a"}) + sq.push(providers.Message{Role: "user", Content: "b"}) + + msgs := sq.dequeue() + if len(msgs) != 2 { + t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs)) + } +} + +func TestSteeringQueue_ConcurrentAccess(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + var wg sync.WaitGroup + const n = MaxQueueSize + + // Push from multiple goroutines + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)}) + }(i) + } + wg.Wait() + + if sq.len() != n { + t.Fatalf("expected %d messages, got %d", n, sq.len()) + } + + // Drain from multiple goroutines + var drained int + var mu sync.Mutex + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if msgs := sq.dequeue(); len(msgs) > 0 { + mu.Lock() + drained += len(msgs) + mu.Unlock() + } + }() + } + wg.Wait() + + if drained != n { + t.Fatalf("expected to drain %d messages, got %d", n, drained) + } +} + +func TestSteeringQueue_Overflow(t *testing.T) { + sq := newSteeringQueue(SteeringOneAtATime) + + // Fill the queue up to its maximum capacity + for i := 0; i < MaxQueueSize; i++ { + err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)}) + if err != nil { + t.Fatalf("unexpected error pushing message %d: %v", i, err) + } + } + + // Sanity check: ensure the queue is actually full + if sq.len() != MaxQueueSize { + t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len()) + } + + // Attempt to push one more message, which MUST fail + err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"}) + + // Assert the error happened and is the exact one we expect + if err == nil { + t.Fatal("expected an error when pushing to a full queue, but got nil") + } + + expectedErr := "steering queue is full" + if err.Error() != expectedErr { + t.Errorf("expected error message %q, got %q", expectedErr, err.Error()) + } +} + +func TestParseSteeringMode(t *testing.T) { + tests := []struct { + input string + expected SteeringMode + }{ + {"", SteeringOneAtATime}, + {"one-at-a-time", SteeringOneAtATime}, + {"all", SteeringAll}, + {"unknown", SteeringOneAtATime}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := parseSteeringMode(tt.input); got != tt.expected { + t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} + +// --- AgentLoop steering integration tests --- + +func TestAgentLoop_Steer_Enqueues(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + al.Steer(providers.Message{Role: "user", Content: "interrupt me"}) + + if al.steering.len() != 1 { + t.Fatalf("expected 1 steering message, got %d", al.steering.len()) + } + + msgs := al.dequeueSteeringMessages() + if len(msgs) != 1 || msgs[0].Content != "interrupt me" { + t.Fatalf("unexpected dequeued message: %v", msgs) + } +} + +func TestAgentLoop_SteeringMode_GetSet(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + if al.SteeringMode() != SteeringOneAtATime { + t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode()) + } + + al.SetSteeringMode(SteeringAll) + if al.SteeringMode() != SteeringAll { + t.Fatalf("expected all mode, got %v", al.SteeringMode()) + } +} + +func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + SteeringMode: "all", + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + if al.SteeringMode() != SteeringAll { + t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode()) + } +} + +func TestAgentLoop_Continue_NoMessages(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + resp, err := al.Continue(context.Background(), "test-session", "test", "chat1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != "" { + t.Fatalf("expected empty response for no steering messages, got %q", resp) + } +} + +func TestAgentLoop_Continue_WithMessages(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "continued response"} + al := NewAgentLoop(cfg, msgBus, provider) + + al.Steer(providers.Message{Role: "user", Content: "new direction"}) + + resp, err := al.Continue(context.Background(), "test-session", "test", "chat1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != "continued response" { + t.Fatalf("expected 'continued response', got %q", resp) + } +} + +// slowTool simulates a tool that takes some time to execute. +type slowTool struct { + name string + duration time.Duration + execCh chan struct{} // closed when Execute starts +} + +func (t *slowTool) Name() string { return t.name } +func (t *slowTool) Description() string { return "slow tool for testing" } +func (t *slowTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} +func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if t.execCh != nil { + close(t.execCh) + } + time.Sleep(t.duration) + return tools.SilentResult(fmt.Sprintf("executed %s", t.name)) +} + +// toolCallProvider returns an LLM response with tool calls on the first call, +// then a direct response on subsequent calls. +type toolCallProvider struct { + mu sync.Mutex + calls int + toolCalls []providers.ToolCall + finalResp string +} + +func (m *toolCallProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.calls++ + + if m.calls == 1 && len(m.toolCalls) > 0 { + return &providers.LLMResponse{ + Content: "", + ToolCalls: m.toolCalls, + }, nil + } + + return &providers.LLMResponse{ + Content: m.finalResp, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *toolCallProvider) GetDefaultModel() string { + return "tool-call-mock" +} + +func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + tool1ExecCh := make(chan struct{}) + tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh} + tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond} + + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "tool_one", + Function: &providers.FunctionCall{ + Name: "tool_one", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "tool_two", + Function: &providers.FunctionCall{ + Name: "tool_two", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "steered response", + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + + // Start processing in a goroutine + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do something", + "test-session", + "test", + "chat1", + ) + resultCh <- result{resp, err} + }() + + // Wait for tool_one to start executing, then enqueue a steering message + select { + case <-tool1ExecCh: + // tool_one has started executing + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tool_one to start") + } + + al.Steer(providers.Message{Role: "user", Content: "change course"}) + + // Get the result + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "steered response" { + t.Fatalf("expected 'steered response', got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for agent loop to complete") + } + + // The provider should have been called twice: + // 1. first call returned tool calls + // 2. second call (after steering) returned the final response + provider.mu.Lock() + calls := provider.calls + provider.mu.Unlock() + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } +} + +func TestAgentLoop_Steering_InitialPoll(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Provider that captures messages it receives + var capturedMessages []providers.Message + var capMu sync.Mutex + provider := &capturingMockProvider{ + response: "ack", + captureFn: func(msgs []providers.Message) { + capMu.Lock() + capturedMessages = make([]providers.Message, len(msgs)) + copy(capturedMessages, msgs) + capMu.Unlock() + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + + // Enqueue a steering message before processing starts + al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"}) + + // Process a normal message - the initial steering poll should inject the steering message + _, err = al.ProcessDirectWithChannel( + context.Background(), + "initial message", + "test-session", + "test", + "chat1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The steering message should have been injected into the conversation + capMu.Lock() + msgs := capturedMessages + capMu.Unlock() + + // Look for the steering message in the captured messages + found := false + for _, m := range msgs { + if m.Content == "pre-enqueued steering" { + found = true + break + } + } + if !found { + t.Fatal("expected steering message to be injected into conversation context") + } +} + +// capturingMockProvider captures messages sent to Chat for inspection. +type capturingMockProvider struct { + response string + calls int + captureFn func([]providers.Message) +} + +func (m *capturingMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.captureFn != nil { + m.captureFn(messages) + } + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *capturingMockProvider) GetDefaultModel() string { + return "capturing-mock" +} + +func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + execCh := make(chan struct{}) + tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh} + tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond} + + // Provider that captures messages on the second call (after tools) + var secondCallMessages []providers.Message + var capMu sync.Mutex + callCount := 0 + + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "slow_tool", + Function: &providers.FunctionCall{ + Name: "slow_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "skipped_tool", + Function: &providers.FunctionCall{ + Name: "skipped_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "done", + } + + // Wrap provider to capture messages on second call + wrappedProvider := &wrappingProvider{ + inner: provider, + onChat: func(msgs []providers.Message) { + capMu.Lock() + callCount++ + if callCount >= 2 { + secondCallMessages = make([]providers.Message, len(msgs)) + copy(secondCallMessages, msgs) + } + capMu.Unlock() + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, wrappedProvider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + + resultCh := make(chan string, 1) + go func() { + resp, _ := al.ProcessDirectWithChannel( + context.Background(), "go", "test-session", "test", "chat1", + ) + resultCh <- resp + }() + + <-execCh + al.Steer(providers.Message{Role: "user", Content: "interrupt!"}) + + select { + case <-resultCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + + // Check that the skipped tool result message is in the conversation + capMu.Lock() + msgs := secondCallMessages + capMu.Unlock() + + foundSkipped := false + for _, m := range msgs { + if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." { + foundSkipped = true + break + } + } + if !foundSkipped { + // Log what we actually got + for i, m := range msgs { + t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80)) + } + t.Fatal("expected skipped tool result for call_2") + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +// wrappingProvider wraps another provider to hook into Chat calls. +type wrappingProvider struct { + inner providers.LLMProvider + onChat func([]providers.Message) +} + +func (w *wrappingProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + if w.onChat != nil { + w.onChat(messages) + } + return w.inner.Chat(ctx, messages, tools, model, opts) +} + +func (w *wrappingProvider) GetDefaultModel() string { + return w.inner.GetDefaultModel() +} + +// Ensure NormalizeToolCall handles our test tool calls. +func init() { + // This is a no-op init; we just need the tool call tests to work + // with the proper argument serialization. + _ = json.Marshal +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 1903412248..a8b8f337fa 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -234,6 +234,7 @@ type AgentDefaults struct { SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` Routing *RoutingConfig `json:"routing,omitempty"` + SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all" } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 189af0a845..5e6b89a4c1 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -35,6 +35,7 @@ func DefaultConfig() *Config { MaxToolIterations: 50, SummarizeMessageThreshold: 20, SummarizeTokenPercent: 75, + SteeringMode: "one-at-a-time", }, }, Bindings: []AgentBinding{}, From b5adb2bb02bdbd567a45a6c5d393b4fd1da09a84 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Fri, 13 Mar 2026 18:16:19 +0100 Subject: [PATCH 2/4] fix loop --- docs/design/steering-spec.md | 101 ++++++++++++++++++++++++++--------- docs/steering.md | 58 +++++++++++++++++--- pkg/agent/loop.go | 101 ++++++++++++++++++++++++----------- 3 files changed, 195 insertions(+), 65 deletions(-) diff --git a/docs/design/steering-spec.md b/docs/design/steering-spec.md index 7cd8dc452c..0951bf864e 100644 --- a/docs/design/steering-spec.md +++ b/docs/design/steering-spec.md @@ -19,30 +19,62 @@ The user's intent reaches the model **as soon as the current tool finishes**, no ```mermaid graph TD subgraph External Callers - CH[Channel Handler] - API[HTTP API] - WS[WebSocket] + TG[Telegram] + DC[Discord] + SL[Slack] end subgraph AgentLoop + BUS[MessageBus] + DRAIN[drainBusToSteering goroutine] SQ[steeringQueue] RLI[runLLMIteration] TE[Tool Execution Loop] LLM[LLM Call] end - CH -->|Steer| SQ - API -->|Steer| SQ - WS -->|Steer| SQ + TG -->|PublishInbound| BUS + DC -->|PublishInbound| BUS + SL -->|PublishInbound| BUS + + BUS -->|ConsumeInbound while busy| DRAIN + DRAIN -->|Steer| SQ RLI -->|1. initial poll| SQ TE -->|2. poll after each tool| SQ - TE -->|3. poll after last tool| SQ SQ -->|pendingMessages| RLI RLI -->|inject into context| LLM ``` +### Bus drain mechanism + +Channels (Telegram, Discord, etc.) publish messages to the `MessageBus` via `PublishInbound`. Without additional wiring, these messages would sit in the bus buffer until the current `processMessage` finishes — meaning steering would never work for real users. + +The solution: when `Run()` starts processing a message, it spawns a **drain goroutine** (`drainBusToSteering`) that keeps consuming from the bus and calling `Steer()`. When `processMessage` returns, the drain is canceled and normal consumption resumes. + +```mermaid +sequenceDiagram + participant Bus + participant Run + participant Drain + participant AgentLoop + + Run->>Bus: ConsumeInbound() → msg + Run->>Drain: spawn drainBusToSteering(ctx) + Run->>Run: processMessage(msg) + + Note over Drain: running concurrently + + Bus-->>Drain: ConsumeInbound() → newMsg + Drain->>AgentLoop: al.transcribeAudioInMessage(ctx, newMsg) + Drain->>AgentLoop: Steer(providers.Message{Content: newMsg.Content}) + + Run->>Run: processMessage returns + Run->>Drain: cancel context + Note over Drain: exits +``` + ## Data Structures ### steeringQueue @@ -59,7 +91,7 @@ A thread-safe FIFO queue, private to the `agent` package. | Method | Description | |--------|-------------| -| `push(msg)` | Appends a message to the queue | +| `push(msg) error` | Appends a message to the queue. Returns an error if the queue is full (`MaxQueueSize`) | | `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty | | `len() int` | Returns the current queue length | | `setMode(mode)` | Updates the dequeue strategy | @@ -86,7 +118,7 @@ A new field was added to `processOptions`: | Method | Signature | Description | |--------|-----------|-------------| -| `Steer` | `Steer(msg providers.Message)` | Enqueues a steering message. Thread-safe, can be called from any goroutine. | +| `Steer` | `Steer(msg providers.Message) error` | Enqueues a steering message. Returns an error if the queue is full or not initialized. Thread-safe, can be called from any goroutine. | | `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. | | `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. | | `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. Returns `""` if queue is empty. | @@ -134,24 +166,18 @@ sequenceDiagram LLM-->>runLLMIteration: response with toolCalls[0..N] loop for each tool call (sequential) - alt i > 0 - ToolExecution->>AgentLoop: dequeueSteeringMessages() - AgentLoop-->>ToolExecution: steeringMessages - - alt steering found - Note over ToolExecution: Mark tool[i..N] as
"Skipped due to queued user message." - ToolExecution-->>runLLMIteration: steeringAfterTools = steeringMessages - Note over ToolExecution: break out of tool loop - end - end - ToolExecution->>ToolExecution: execute tool[i] ToolExecution->>ToolExecution: process result,
append to messages[] - alt last tool (i == N-1) - ToolExecution->>AgentLoop: dequeueSteeringMessages() - AgentLoop-->>ToolExecution: steeringMessages (may be empty) + ToolExecution->>AgentLoop: dequeueSteeringMessages() + AgentLoop-->>ToolExecution: steeringMessages + + alt steering found + opt remaining tools > 0 + Note over ToolExecution: Mark tool[i+1..N-1] as
"Skipped due to queued user message." + end Note over ToolExecution: steeringAfterTools = steeringMessages + Note over ToolExecution: break out of tool loop end end @@ -168,12 +194,11 @@ sequenceDiagram | # | Location | When | Purpose | |---|----------|------|---------| | 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context | -| 2 | Between tool calls, before tool `[i]` where `i > 0` | After each tool finishes | Interrupt mid-batch if the user sent a steering message | -| 3 | After the last tool in the batch | After tool `[N-1]` finishes | Catch messages that arrived during the last tool's execution | +| 2 | After every tool completes (including the first and the last) | Immediately after each tool's result is processed | Interrupt the batch as early as possible — if steering is found and there are remaining tools, they are all skipped | ### What happens to skipped tools -When steering interrupts a tool batch at index `i`, all tools from `i` to `N-1` are **not executed**. Instead, a tool result message is generated for each: +When steering interrupts a tool batch after tool `[i]` completes, all tools from `[i+1]` to `[N-1]` are **not executed**. Instead, a tool result message is generated for each: ```json { @@ -213,6 +238,27 @@ This allows **one extra iteration** when steering arrives right at the max itera > **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost. +### Why skip remaining tools (instead of letting them finish) + +Two strategies were considered when a steering message is detected mid-batch: + +1. **Skip remaining tools** (chosen) — stop executing, mark the rest as skipped, inject steering +2. **Finish all tools, then inject** — let everything run, append steering afterwards + +Strategy 2 was rejected for three reasons: + +**Irreversible side effects.** Tools can send emails, write files, spawn subagents, or call external APIs. If the user says "stop" or "change direction", those actions have already happened and cannot be undone. + +| Tool batch | Steering | Skip (1) | Finish (2) | +|---|---|---|---| +| `[search, send_email]` | "don't send it" | Email not sent | Email sent | +| `[query, write_file, spawn]` | "wrong database" | Only query runs | File + subagent wasted | +| `[fetch₁, fetch₂, fetch₃, write]` | topic change | 1 fetch | 3 fetches + write, all discarded | + +**Wasted latency.** Tools like web fetches and API calls take seconds each. In a 3-tool batch averaging 3-4s per tool, the user would wait 10+ seconds for work that gets thrown away. + +**The LLM retains full awareness.** Skipped tools receive an explicit `"Skipped due to queued user message."` result, so the model knows what was not done and can decide whether to re-execute with the new context or take a different path. + ## The Continue() method `Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime. @@ -255,3 +301,6 @@ flowchart TD | Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. | | `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. | | Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. | +| Bus drain goroutine in `Run()` | Channels (Telegram, Discord, etc.) publish to the bus via `PublishInbound`. Without the drain, messages would queue in the bus channel buffer and only be consumed after `processMessage` returns — defeating the purpose of steering. The drain goroutine bridges the gap by consuming new bus messages and calling `Steer()` while the agent is busy. | +| Audio transcription before steering | The drain goroutine calls `al.transcribeAudioInMessage(ctx, msg)` before steering, so voice messages are converted to text before the agent sees them. If transcription fails, the error is silently discarded and the original message is steered as-is. | +| `MaxQueueSize = 10` | Prevents unbounded memory growth if a user sends many messages while the agent is busy. Excess messages are dropped with a warning. | diff --git a/docs/steering.md b/docs/steering.md index 6f4cbd27b4..ad08f84250 100644 --- a/docs/steering.md +++ b/docs/steering.md @@ -49,13 +49,16 @@ The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as ### Steer — Send a steering message ```go -agentLoop.Steer(providers.Message{ +err := agentLoop.Steer(providers.Message{ Role: "user", Content: "change direction, focus on X instead", }) +if err != nil { + // Queue is full (MaxQueueSize=10) or not initialized +} ``` -The message is enqueued in a thread-safe manner. It will be picked up at the next polling point (after the current tool finishes). +The message is enqueued in a thread-safe manner. Returns an error if the queue is full or not initialized. It will be picked up at the next polling point (after the current tool finishes). ### SteeringMode / SetSteeringMode @@ -73,6 +76,9 @@ When the agent is idle (it has finished processing and its last message was from ```go response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID) +if err != nil { + // Error (e.g. "no default agent available") +} if response == "" { // No steering messages in queue, the agent stays idle } @@ -82,21 +88,48 @@ if response == "" { ## Polling points in the loop -Steering is checked at **three points** in the agent cycle: +Steering is checked at **two points** in the agent cycle: 1. **At loop start** — before the first LLM call, to catch messages enqueued during setup -2. **After each tool** — between tool calls within the same batch -3. **After the last tool** — to catch messages that arrived while the last tool was executing +2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately + +## Why remaining tools are skipped + +When a steering message is detected, all remaining tools in the batch are skipped rather than executed. The alternative — let all tools finish and inject the steering message afterwards — was considered and rejected. Here is why. + +### Preventing unwanted side effects + +Tools can have **irreversible side effects**. If the user says "no, wait" while the agent is mid-batch, executing the remaining tools means those side effects happen anyway: + +| Tool batch | Steering message | With skip | Without skip | +|---|---|---|---| +| `[web_search, send_email]` | "don't send it" | Email **not** sent | Email sent, damage done | +| `[query_db, write_file, spawn_agent]` | "use another database" | Only the query runs | File written + subagent spawned, all wasted | +| `[search₁, search₂, search₃, write_file]` | user changes topic entirely | 1 search | 3 searches + file write, all irrelevant | + +### Avoiding wasted time + +Tools that take seconds (web fetches, API calls, database queries) would all run to completion before the agent sees the user's correction. In a batch of 3 tools each taking 3-4 seconds, that's 10+ seconds of work that will be discarded. -## Skipped tool behavior +With skipping, the agent reacts as soon as the current tool finishes — typically within a few seconds instead of waiting for the entire batch. -When steering interrupts a batch of tool calls, the tools that were not yet executed receive a `tool` result with: +### The LLM gets full context + +Skipped tools receive an explicit error result (`"Skipped due to queued user message."`), so the model knows exactly which actions were not performed. It can then decide whether to re-execute them with the new context, or take a different path entirely. + +### Trade-off: sequential execution + +Skipping requires tools to run **sequentially** (the previous implementation ran them in parallel). This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal compared to the benefit of being able to stop unwanted actions. + +## Skipped tool result format + +When steering interrupts a batch, each tool that was not executed receives a `tool` result with: ``` Content: "Skipped due to queued user message." ``` -This is saved to the session and sent to the model, so it is aware that some requested actions were not performed. +This is saved to the session via `AddFullMessage` and sent to the model, so it is aware that some requested actions were not performed. ## Full flow example @@ -117,8 +150,17 @@ This is saved to the session and sent to the model, so it is aware that some req 7. LLM receives the full updated context and responds accordingly ``` +## Automatic bus drain + +When the agent loop (`Run()`) starts processing a message, it spawns a background goroutine that keeps consuming new inbound messages from the bus. These messages are automatically redirected into the steering queue via `Steer()`. This means: + +- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy +- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is +- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes + ## Notes - Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue. - With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually. - With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once. +- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f3ed3ca6a8..eb132c2dbf 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -260,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } + // Start a goroutine that drains the bus while processMessage is + // running. Any inbound messages that arrive during processing are + // redirected into the steering queue so the agent loop can pick + // them up between tool calls. + drainCtx, drainCancel := context.WithCancel(ctx) + go al.drainBusToSteering(drainCtx) + // Process message func() { // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. @@ -275,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error { // } // }() + defer drainCancel() + response, err := al.processMessage(ctx, msg) if err != nil { response = fmt.Sprintf("Error processing message: %v", err) @@ -321,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error { return nil } +// drainBusToSteering continuously consumes inbound messages and redirects +// them into the steering queue. It runs in a goroutine while processMessage +// is active and stops when drainCtx is canceled (i.e., processMessage returns). +func (al *AgentLoop) drainBusToSteering(ctx context.Context) { + for { + msg, ok := al.bus.ConsumeInbound(ctx) + if !ok { + return + } + + // Transcribe audio if needed before steering, so the agent sees text. + msg, _ = al.transcribeAudioInMessage(ctx, msg) + + logger.InfoCF("agent", "Redirecting inbound message to steering queue", + map[string]any{ + "channel": msg.Channel, + "sender_id": msg.SenderID, + "content_len": len(msg.Content), + }) + + if err := al.Steer(providers.Message{ + Role: "user", + Content: msg.Content, + }); err != nil { + logger.WarnCF("agent", "Failed to steer message, will be lost", + map[string]any{ + "error": err.Error(), + "channel": msg.Channel, + }) + } + } +} + func (al *AgentLoop) Stop() { al.running.Store(false) } @@ -1285,33 +1327,6 @@ func (al *AgentLoop) runLLMIteration( var steeringAfterTools []providers.Message for i, tc := range normalizedToolCalls { - // Check for steering before executing (except for the first tool) - if i > 0 { - if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { - steeringAfterTools = steerMsgs - logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", - map[string]any{ - "agent_id": agent.ID, - "skipped_from": i, - "total_tools": len(normalizedToolCalls), - "steering_count": len(steerMsgs), - }) - - // Mark remaining tool calls as skipped - for j := i; j < len(normalizedToolCalls); j++ { - skippedTC := normalizedToolCalls[j] - toolResultMsg := providers.Message{ - Role: "tool", - Content: "Skipped due to queued user message.", - ToolCallID: skippedTC.ID, - } - messages = append(messages, toolResultMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) - } - break - } - } - argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), @@ -1414,11 +1429,35 @@ func (al *AgentLoop) runLLMIteration( messages = append(messages, toolResultMsg) agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) - // After the last tool, also check for steering messages. - if i == len(normalizedToolCalls)-1 { - if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { - steeringAfterTools = steerMsgs + // After EVERY tool (including the first and last), check for + // steering messages. If found and there are remaining tools, + // skip them all. + if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + remaining := len(normalizedToolCalls) - i - 1 + if remaining > 0 { + logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", + map[string]any{ + "agent_id": agent.ID, + "completed": i + 1, + "skipped": remaining, + "total_tools": len(normalizedToolCalls), + "steering_count": len(steerMsgs), + }) + + // Mark remaining tool calls as skipped + for j := i + 1; j < len(normalizedToolCalls); j++ { + skippedTC := normalizedToolCalls[j] + toolResultMsg := providers.Message{ + Role: "tool", + Content: "Skipped due to queued user message.", + ToolCallID: skippedTC.ID, + } + messages = append(messages, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + } } + steeringAfterTools = steerMsgs + break } } From 950b373e2b0f415983947d36b182a627084c5014 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Fri, 13 Mar 2026 18:32:26 +0100 Subject: [PATCH 3/4] fix lint --- pkg/agent/steering_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index 173814d066..f9cae38da4 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -320,6 +320,7 @@ func (t *slowTool) Parameters() map[string]any { "properties": map[string]any{}, } } + func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { if t.execCh != nil { close(t.execCh) From a7756e2a1c8515d1a86af30af83ca8b0fb7a97cf Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Fri, 13 Mar 2026 18:45:51 +0100 Subject: [PATCH 4/4] fix lint --- pkg/agent/steering_test.go | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index f9cae38da4..e8cdb23449 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -202,9 +202,19 @@ func TestParseSteeringMode(t *testing.T) { // --- AgentLoop steering integration tests --- func TestAgentLoop_Steer_Enqueues(t *testing.T) { - al, _, _, _, cleanup := newTestAgentLoop(t) + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup() + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + al.Steer(providers.Message{Role: "user", Content: "interrupt me"}) if al.steering.len() != 1 { @@ -218,9 +228,19 @@ func TestAgentLoop_Steer_Enqueues(t *testing.T) { } func TestAgentLoop_SteeringMode_GetSet(t *testing.T) { - al, _, _, _, cleanup := newTestAgentLoop(t) + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup() + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + if al.SteeringMode() != SteeringOneAtATime { t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode()) } @@ -260,9 +280,19 @@ func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) { } func TestAgentLoop_Continue_NoMessages(t *testing.T) { - al, _, _, _, cleanup := newTestAgentLoop(t) + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup() + if cfg == nil { + t.Fatal("expected config to be initialized") + } + if msgBus == nil { + t.Fatal("expected message bus to be initialized") + } + if provider == nil { + t.Fatal("expected provider to be initialized") + } + resp, err := al.Continue(context.Background(), "test-session", "test", "chat1") if err != nil { t.Fatalf("unexpected error: %v", err)