diff --git a/docs/steering.md b/docs/steering.md index ad08f84250..63294ac5f0 100644 --- a/docs/steering.md +++ b/docs/steering.md @@ -21,6 +21,18 @@ Agent Loop ▼ └─ new LLM turn with steering message ``` +## Scoped queues + +Steering is now isolated per resolved session scope, not stored in a single +global queue. + +- The active turn writes and reads from its own scope key (usually the routed session key such as `agent::...`) +- `Steer()` still works outside an active turn through a legacy fallback queue +- `Continue()` first dequeues messages for the requested session scope, then falls back to the legacy queue for backwards compatibility + +This prevents a message arriving from another chat, DM peer, or routed agent +session from being injected into the wrong conversation. + ## Configuration In `config.json`, under `agents.defaults`: @@ -86,12 +98,18 @@ if response == "" { `Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input). +`Continue` also resolves the target agent from the provided session key, so +agent-scoped sessions continue on the correct agent instead of always using +the default one. + ## Polling points in the loop -Steering is checked at **two points** in the agent cycle: +Steering is checked at the following points in the agent cycle: 1. **At loop start** — before the first LLM call, to catch messages enqueued during setup 2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately +3. **After a direct LLM response** — if a new steering message arrived while the model was generating a non-tool response, the loop continues instead of returning a stale answer +4. 
**Right before the turn is finalized** — if steering arrived at the very end of the turn, the agent immediately starts a continuation turn instead of leaving the message orphaned in the queue ## Why remaining tools are skipped @@ -156,11 +174,26 @@ When the agent loop (`Run()`) starts processing a message, it spawns a backgroun - Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy - Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is +- Only messages that resolve to the **same steering scope** as the active turn are redirected. Messages for other chats/sessions are requeued onto the inbound bus so they can be processed normally; the drain goroutine stops after requeuing such a message, so any later inbound messages (including ones for the active scope) stay on the bus until the turn ends +- `system` inbound messages are not treated as steering input - When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes +## Steering with media + +Steering messages can include `Media` refs, just like normal inbound user +messages. + +- The original `media://` refs are preserved in session history via `AddFullMessage` +- Before the next provider call, steering messages go through the normal media resolution pipeline +- Image refs are converted to data URLs for multimodal providers; non-image refs are resolved the same way as standard inbound media + +This applies both to in-turn steering and to idle-session continuation through +`Continue()`. + ## Notes - Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue. - With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually. - With `all` mode, all pending messages are combined into a single injection. 
Useful when you want the agent to receive all the context at once. - The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped. +- Manual `Steer()` calls made outside an active turn still go to the legacy fallback queue, so older integrations keep working. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f54482ae87..a3a23fb3d6 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -64,11 +64,12 @@ type processOptions struct { ChatID string // Target chat ID for tool execution UserMessage string // User message content (may include prefix) Media []string // media:// refs from inbound message - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) - SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) + InitialSteeringMessages []providers.Message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) + SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) } type continuationTarget struct { @@ -271,11 +272,14 @@ func (al *AgentLoop) Run(ctx context.Context) error { } // Start a goroutine that drains the bus while processMessage is - // running. Any inbound messages that arrive during processing are - // redirected into the steering queue so the agent loop can pick - // them up between tool calls. - drainCtx, drainCancel := context.WithCancel(ctx) - go al.drainBusToSteering(drainCtx) + // running. 
Only messages that resolve to the active turn scope are + // redirected into steering; other inbound messages are requeued. + drainCancel := func() {} + if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok { + drainCtx, cancel := context.WithCancel(ctx) + drainCancel = cancel + go al.drainBusToSteering(drainCtx, activeScope, activeAgentID) + } // Process message func() { @@ -292,16 +296,21 @@ func (al *AgentLoop) Run(ctx context.Context) error { // } // }() - defer drainCancel() + drainCanceled := false + cancelDrain := func() { + if drainCanceled { + return + } + drainCancel() + drainCanceled = true + } + defer cancelDrain() response, err := al.processMessage(ctx, msg) if err != nil { response = fmt.Sprintf("Error processing message: %v", err) } - - if response != "" { - al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response) - } + finalResponse := response target, targetErr := al.buildContinuationTarget(msg) if targetErr != nil { @@ -313,16 +322,20 @@ func (al *AgentLoop) Run(ctx context.Context) error { return } if target == nil { + cancelDrain() + if finalResponse != "" { + al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, finalResponse) + } return } - for al.pendingSteeringCount() > 0 { + for al.pendingSteeringCountForScope(target.SessionKey) > 0 { logger.InfoCF("agent", "Continuing queued steering after turn end", map[string]any{ "channel": target.Channel, "chat_id": target.ChatID, "session_key": target.SessionKey, - "queue_depth": al.pendingSteeringCount(), + "queue_depth": al.pendingSteeringCountForScope(target.SessionKey), }) continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) @@ -339,7 +352,39 @@ func (al *AgentLoop) Run(ctx context.Context) error { return } - al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued) + finalResponse = continued + } + + cancelDrain() + + for al.pendingSteeringCountForScope(target.SessionKey) > 0 { + logger.InfoCF("agent", "Draining 
steering queued during turn shutdown", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "session_key": target.SessionKey, + "queue_depth": al.pendingSteeringCountForScope(target.SessionKey), + }) + + continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) + if continueErr != nil { + logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "error": continueErr.Error(), + }) + return + } + if continued == "" { + break + } + + finalResponse = continued + } + + if finalResponse != "" { + al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, finalResponse) } }() } @@ -349,15 +394,27 @@ func (al *AgentLoop) Run(ctx context.Context) error { } // drainBusToSteering continuously consumes inbound messages and redirects -// them into the steering queue. It runs in a goroutine while processMessage -// is active and stops when drainCtx is canceled (i.e., processMessage returns). -func (al *AgentLoop) drainBusToSteering(ctx context.Context) { +// messages from the active scope into the steering queue. Messages from other +// scopes are requeued so they can be processed normally after the active turn. +func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) { for { msg, ok := al.bus.ConsumeInbound(ctx) if !ok { return } + msgScope, _, scopeOK := al.resolveSteeringTarget(msg) + if !scopeOK || msgScope != activeScope { + if err := al.requeueInboundMessage(msg); err != nil { + logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{ + "error": err.Error(), + "channel": msg.Channel, + "sender_id": msg.SenderID, + }) + } + return + } + // Transcribe audio if needed before steering, so the agent sees text. 
msg, _ = al.transcribeAudioInMessage(ctx, msg) @@ -366,11 +423,13 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context) { "channel": msg.Channel, "sender_id": msg.SenderID, "content_len": len(msg.Content), + "scope": activeScope, }) - if err := al.Steer(providers.Message{ + if err := al.enqueueSteeringMessage(activeScope, activeAgentID, providers.Message{ Role: "user", Content: msg.Content, + Media: append([]string(nil), msg.Media...), }); err != nil { logger.WarnCF("agent", "Failed to steer message, will be lost", map[string]any{ @@ -422,13 +481,6 @@ func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatI }) } -func (al *AgentLoop) pendingSteeringCount() int { - if al.steering == nil { - return 0 - } - return al.steering.len() -} - func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) { if msg.Channel == "system" { return nil, nil @@ -1085,6 +1137,25 @@ func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { return route.SessionKey } +func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) { + if msg.Channel == "system" { + return "", "", false + } + + route, agent, err := al.resolveMessageRoute(msg) + if err != nil || agent == nil { + return "", "", false + } + + return resolveScopeKey(route, msg.SessionKey), agent.ID, true +} + +func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error { + pubCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + return al.bus.PublishInbound(pubCtx, msg) +} + func (al *AgentLoop) processSystemMessage( ctx context.Context, msg bus.InboundMessage, @@ -1346,16 +1417,25 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er } } - if !ts.opts.NoHistory { - rootMsg := providers.Message{Role: "user", Content: ts.userMessage} - ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content) + if !ts.opts.NoHistory 
&& (strings.TrimSpace(ts.userMessage) != "" || len(ts.media) > 0) { + rootMsg := providers.Message{ + Role: "user", + Content: ts.userMessage, + Media: append([]string(nil), ts.media...), + } + if len(rootMsg.Media) > 0 { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, rootMsg) + } else { + ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content) + } ts.recordPersistedMessage(rootMsg) } activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages) - var pendingMessages []providers.Message + pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...) var finalContent string +turnLoop: for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool { graceful, _ := ts.gracefulInterruptRequested() return graceful @@ -1369,19 +1449,24 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.setIteration(iteration) ts.setPhase(TurnPhaseRunning) - if iteration > 1 || !ts.opts.SkipInitialSteeringPoll { - if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + if iteration > 1 { + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + pendingMessages = append(pendingMessages, steerMsgs...) + } + } else if !ts.opts.SkipInitialSteeringPoll { + if steerMsgs := al.dequeueSteeringMessagesForScopeWithFallback(ts.sessionKey); len(steerMsgs) > 0 { pendingMessages = append(pendingMessages, steerMsgs...) 
} } if len(pendingMessages) > 0 { + resolvedPending := resolveMediaRefs(pendingMessages, al.mediaStore, maxMediaSize) totalContentLen := 0 - for _, pm := range pendingMessages { - messages = append(messages, pm) + for i, pm := range pendingMessages { + messages = append(messages, resolvedPending[i]) totalContentLen += len(pm.Content) if !ts.opts.NoHistory { - ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content) + ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm) ts.recordPersistedMessage(pm) } logger.InfoCF("agent", "Injected steering message into context", @@ -1389,6 +1474,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er "agent_id": ts.agent.ID, "iteration": iteration, "content_len": len(pm.Content), + "media_count": len(pm.Media), }) } al.emitEvent( @@ -1660,10 +1746,21 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er }) if len(response.ToolCalls) == 0 || gracefulTerminal { - finalContent = response.Content - if finalContent == "" && response.ReasoningContent != "" { - finalContent = response.ReasoningContent + responseContent := response.Content + if responseContent == "" && response.ReasoningContent != "" { + responseContent = response.ReasoningContent + } + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + logger.InfoCF("agent", "Steering arrived after direct LLM response; continuing turn", + map[string]any{ + "agent_id": ts.agent.ID, + "iteration": iteration, + "steering_count": len(steerMsgs), + }) + pendingMessages = append(pendingMessages, steerMsgs...) 
+ continue } + finalContent = responseContent logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]any{ "agent_id": ts.agent.ID, @@ -1870,7 +1967,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.recordPersistedMessage(toolResultMsg) } - if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { pendingMessages = append(pendingMessages, steerMsgs...) } @@ -1926,6 +2023,18 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er }) } + if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 { + logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing", + map[string]any{ + "agent_id": ts.agent.ID, + "steering_count": len(steerMsgs), + "session_key": ts.sessionKey, + }) + pendingMessages = append(pendingMessages, steerMsgs...) + finalContent = "" + goto turnLoop + } + if ts.hardAbortRequested() { turnStatus = TurnEndStatusAborted return al.abortTurn(ts) diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 77c2e0c177..eb8afa1dde 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -8,6 +8,7 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" ) // SteeringMode controls how queued steering messages are dequeued. @@ -20,6 +21,9 @@ const ( SteeringAll SteeringMode = "all" // MaxQueueSize number of possible messages in the Steering Queue MaxQueueSize = 10 + // manualSteeringScope is the legacy fallback queue used when no active + // turn/session scope is available. + manualSteeringScope = "__manual__" ) // parseSteeringMode normalizes a config string into a SteeringMode. 
@@ -35,56 +39,117 @@ func parseSteeringMode(s string) SteeringMode { // steeringQueue is a thread-safe queue of user messages that can be injected // into a running agent loop to interrupt it between tool calls. type steeringQueue struct { - mu sync.Mutex - queue []providers.Message - mode SteeringMode + mu sync.Mutex + queues map[string][]providers.Message + mode SteeringMode } func newSteeringQueue(mode SteeringMode) *steeringQueue { return &steeringQueue{ - mode: mode, + queues: make(map[string][]providers.Message), + mode: mode, } } -// push enqueues a steering message. +func normalizeSteeringScope(scope string) string { + scope = strings.TrimSpace(scope) + if scope == "" { + return manualSteeringScope + } + return scope +} + +// push enqueues a steering message in the legacy fallback scope. func (sq *steeringQueue) push(msg providers.Message) error { + return sq.pushScope(manualSteeringScope, msg) +} + +// pushScope enqueues a steering message for the provided scope. +func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error { sq.mu.Lock() defer sq.mu.Unlock() - if len(sq.queue) >= MaxQueueSize { + + scope = normalizeSteeringScope(scope) + queue := sq.queues[scope] + if len(queue) >= MaxQueueSize { return fmt.Errorf("steering queue is full") } - sq.queue = append(sq.queue, msg) + sq.queues[scope] = append(queue, msg) return nil } -// dequeue removes and returns pending steering messages according to the -// configured mode. Returns nil when the queue is empty. +// dequeue removes and returns pending steering messages from the legacy +// fallback scope according to the configured mode. func (sq *steeringQueue) dequeue() []providers.Message { + return sq.dequeueScope(manualSteeringScope) +} + +// dequeueScope removes and returns pending steering messages for the provided +// scope according to the configured mode. 
+func (sq *steeringQueue) dequeueScope(scope string) []providers.Message { + sq.mu.Lock() + defer sq.mu.Unlock() + + return sq.dequeueLocked(normalizeSteeringScope(scope)) +} + +// dequeueScopeWithFallback drains the scoped queue first and falls back to the +// legacy manual scope for backwards compatibility. +func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message { sq.mu.Lock() defer sq.mu.Unlock() - if len(sq.queue) == 0 { + scope = strings.TrimSpace(scope) + if scope != "" { + if msgs := sq.dequeueLocked(scope); len(msgs) > 0 { + return msgs + } + } + + return sq.dequeueLocked(manualSteeringScope) +} + +func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message { + queue := sq.queues[scope] + if len(queue) == 0 { return nil } switch sq.mode { case SteeringAll: - msgs := sq.queue - sq.queue = nil + msgs := append([]providers.Message(nil), queue...) + delete(sq.queues, scope) return msgs - default: // one-at-a-time - msg := sq.queue[0] - sq.queue[0] = providers.Message{} // Clear reference for GC - sq.queue = sq.queue[1:] + default: + msg := queue[0] + queue[0] = providers.Message{} // Clear reference for GC + queue = queue[1:] + if len(queue) == 0 { + delete(sq.queues, scope) + } else { + sq.queues[scope] = queue + } return []providers.Message{msg} } } -// len returns the number of queued messages. +// len returns the number of queued messages across all scopes. func (sq *steeringQueue) len() int { sq.mu.Lock() defer sq.mu.Unlock() - return len(sq.queue) + + total := 0 + for _, queue := range sq.queues { + total += len(queue) + } + return total +} + +// lenScope returns the number of queued messages for a specific scope. +func (sq *steeringQueue) lenScope(scope string) int { + sq.mu.Lock() + defer sq.mu.Unlock() + return len(sq.queues[normalizeSteeringScope(scope)]) } // setMode updates the steering mode. 
@@ -101,26 +166,40 @@ func (sq *steeringQueue) getMode() SteeringMode { return sq.mode } -// --- AgentLoop steering API --- - // Steer enqueues a user message to be injected into the currently running // agent loop. The message will be picked up after the current tool finishes // executing, causing any remaining tool calls in the batch to be skipped. func (al *AgentLoop) Steer(msg providers.Message) error { + scope := "" + agentID := "" + if ts := al.getActiveTurnState(); ts != nil { + scope = ts.sessionKey + agentID = ts.agentID + } + return al.enqueueSteeringMessage(scope, agentID, msg) +} + +func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error { if al.steering == nil { return fmt.Errorf("steering queue is not initialized") } - if err := al.steering.push(msg); err != nil { + + if err := al.steering.pushScope(scope, msg); err != nil { logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{ "error": err.Error(), "role": msg.Role, + "scope": normalizeSteeringScope(scope), }) return err } + + queueDepth := al.steering.lenScope(scope) logger.DebugCF("agent", "Steering message enqueued", map[string]any{ "role": msg.Role, "content_len": len(msg.Content), - "queue_len": al.steering.len(), + "media_count": len(msg.Media), + "queue_len": queueDepth, + "scope": normalizeSteeringScope(scope), }) meta := EventMeta{ @@ -129,11 +208,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error { } if ts := al.getActiveTurnState(); ts != nil { meta = ts.eventMeta("Steer", "turn.interrupt.received") - } else if registry := al.GetRegistry(); registry != nil { - if agent := registry.GetDefaultAgent(); agent != nil { - meta.AgentID = agent.ID + } else { + if strings.TrimSpace(agentID) != "" { + meta.AgentID = agentID + } + normalizedScope := normalizeSteeringScope(scope) + if normalizedScope != manualSteeringScope { + meta.SessionKey = normalizedScope + } + if meta.AgentID == "" { + if registry := al.GetRegistry(); 
registry != nil { + if agent := registry.GetDefaultAgent(); agent != nil { + meta.AgentID = agent.ID + } + } } } + al.emitEvent( EventKindInterruptReceived, meta, @@ -141,7 +232,7 @@ func (al *AgentLoop) Steer(msg providers.Message) error { Kind: InterruptKindSteering, Role: msg.Role, ContentLen: len(msg.Content), - QueueDepth: al.steering.len(), + QueueDepth: queueDepth, }, ) @@ -165,7 +256,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) { } // dequeueSteeringMessages is the internal method called by the agent loop -// to poll for steering messages. Returns nil when no messages are pending. +// to poll for steering messages in the legacy fallback scope. func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { if al.steering == nil { return nil @@ -173,6 +264,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { return al.steering.dequeue() } +func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message { + if al.steering == nil { + return nil + } + return al.steering.dequeueScope(scope) +} + +func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message { + if al.steering == nil { + return nil + } + return al.steering.dequeueScopeWithFallback(scope) +} + +func (al *AgentLoop) pendingSteeringCountForScope(scope string) int { + if al.steering == nil { + return 0 + } + return al.steering.lenScope(scope) +} + +func (al *AgentLoop) continueWithSteeringMessages( + ctx context.Context, + agent *AgentInstance, + sessionKey, channel, chatID string, + steeringMsgs []providers.Message, +) (string, error) { + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: channel, + ChatID: chatID, + DefaultResponse: defaultResponse, + EnableSummary: true, + SendResponse: false, + InitialSteeringMessages: steeringMsgs, + SkipInitialSteeringPoll: true, + }) +} + +func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { + registry 
:= al.GetRegistry() + if registry == nil { + return nil + } + + if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil { + if agent, ok := registry.GetAgent(parsed.AgentID); ok { + return agent + } + } + + return registry.GetDefaultAgent() +} + // Continue resumes an idle agent by dequeuing any pending steering messages // and running them through the agent loop. This is used when the agent's last // message was from the assistant (i.e., it has stopped processing) and the @@ -184,14 +329,14 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s return "", fmt.Errorf("turn %s is still active", active.TurnID) } - steeringMsgs := al.dequeueSteeringMessages() + steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey) if len(steeringMsgs) == 0 { return "", nil } - agent := al.GetRegistry().GetDefaultAgent() + agent := al.agentForSession(sessionKey) if agent == nil { - return "", fmt.Errorf("no default agent available") + return "", fmt.Errorf("no agent available for session %q", sessionKey) } if tool, ok := agent.Tools.Get("message"); ok { @@ -200,23 +345,7 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s } } - // Build a combined user message from the steering messages. 
- var contents []string - for _, msg := range steeringMsgs { - contents = append(contents, msg.Content) - } - combinedContent := strings.Join(contents, "\n") - - return al.runAgentLoop(ctx, agent, processOptions{ - SessionKey: sessionKey, - Channel: channel, - ChatID: chatID, - UserMessage: combinedContent, - DefaultResponse: defaultResponse, - EnableSummary: true, - SendResponse: false, - SkipInitialSteeringPoll: true, - }) + return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs) } func (al *AgentLoop) InterruptGraceful(hint string) error { diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index bb5d42c73b..cf2e86904c 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -5,13 +5,16 @@ import ( "encoding/json" "fmt" "os" + "path/filepath" "reflect" + "strings" "sync" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" @@ -337,6 +340,96 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) { } } +func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + + activeMsg := bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "active turn", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + activeScope, activeAgentID, ok := 
al.resolveSteeringTarget(activeMsg) + if !ok { + t.Fatal("expected active message to resolve to a steering scope") + } + + otherMsg := bus.InboundMessage{ + Channel: "telegram", + SenderID: "user2", + ChatID: "chat2", + Content: "other session", + Peer: bus.Peer{ + Kind: "direct", + ID: "user2", + }, + } + otherScope, _, ok := al.resolveSteeringTarget(otherMsg) + if !ok { + t.Fatal("expected other message to resolve to a steering scope") + } + if otherScope == activeScope { + t.Fatalf("expected different steering scopes, got same scope %q", activeScope) + } + + if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + al.drainBusToSteering(ctx, activeScope, activeAgentID) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for drainBusToSteering to stop") + } + + if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 { + t.Fatalf("expected no steering messages for active scope, got %v", msgs) + } + + requeued, ok := msgBus.ConsumeInbound(context.Background()) + if !ok { + t.Fatal("expected message to be requeued on the inbound bus") + } + if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID || + requeued.SenderID != otherMsg.SenderID || requeued.Content != otherMsg.Content { + t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg) + } +} + // slowTool simulates a tool that takes some time to execute. 
type slowTool struct { name string @@ -472,6 +565,52 @@ func (p *lateSteeringProvider) GetDefaultModel() string { return "late-steering-mock" } +type blockingDirectProvider struct { + mu sync.Mutex + calls int + firstStarted chan struct{} + releaseFirst chan struct{} + firstResp string + finalResp string +} + +func (p *blockingDirectProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + p.calls++ + call := p.calls + firstStarted := p.firstStarted + releaseFirst := p.releaseFirst + firstResp := p.firstResp + finalResp := p.finalResp + if call == 1 && p.firstStarted != nil { + close(p.firstStarted) + p.firstStarted = nil + } + p.mu.Unlock() + + if call == 1 { + select { + case <-releaseFirst: + case <-ctx.Done(): + return nil, ctx.Err() + } + return &providers.LLMResponse{Content: firstResp}, nil + } + + _ = firstStarted + return &providers.LLMResponse{Content: finalResp}, nil +} + +func (p *blockingDirectProvider) GetDefaultModel() string { + return "blocking-direct-mock" +} + type interruptibleTool struct { name string started chan struct{} @@ -744,18 +883,16 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { out1, ok := msgBus.SubscribeOutbound(subCtx) if !ok { - t.Fatal("expected first outbound response") + t.Fatal("expected outbound response") } - if out1.Content != "first response" { - t.Fatalf("expected first response, got %q", out1.Content) + if out1.Content != "continued response" { + t.Fatalf("expected continued response, got %q", out1.Content) } - out2, ok := msgBus.SubscribeOutbound(subCtx) - if !ok { - t.Fatal("expected continued outbound response") - } - if out2.Content != "continued response" { - t.Fatalf("expected continued response, got %q", out2.Content) + noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancelNoExtra() + if out2, ok 
:= msgBus.SubscribeOutbound(noExtraCtx); ok { + t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content) } cancelRun() @@ -789,6 +926,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { } } +func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + provider := &blockingDirectProvider{ + firstStarted: make(chan struct{}), + releaseFirst: make(chan struct{}), + firstResp: "stale direct response", + finalResp: "fresh response after steering", + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + + resultCh := make(chan struct { + resp string + err error + }, 1) + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "initial request", + sessionKey, + "test", + "chat1", + ) + resultCh <- struct { + resp string + err error + }{resp: resp, err: err} + }() + + select { + case <-provider.firstStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first LLM call to start") + } + + if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil { + t.Fatalf("Steer failed: %v", err) + } + close(provider.releaseFirst) + + select { + case result := <-resultCh: + if result.err != nil { + t.Fatalf("unexpected error: %v", result.err) + } + if result.resp != "fresh response after steering" { + t.Fatalf("expected refreshed response, got %q", result.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for ProcessDirectWithChannel") + } + + 
provider.mu.Lock() + calls := provider.calls + provider.mu.Unlock() + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } + + if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 { + t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs) + } +} + +func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + store := media.NewFileMediaStore() + pngPath := filepath.Join(tmpDir, "steer.png") + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, + 0x00, 0x00, 0x00, 0x0D, + 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, + 0x00, 0x00, 0x00, + 0x90, 0x77, 0x53, 0xDE, + } + if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + var capturedMessages []providers.Message + var capMu sync.Mutex + provider := &capturingMockProvider{ + response: "ack", + captureFn: func(msgs []providers.Message) { + capMu.Lock() + defer capMu.Unlock() + capturedMessages = append([]providers.Message(nil), msgs...) 
+ }, + } + + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + al.SetMediaStore(store) + + if err = al.Steer(providers.Message{ + Role: "user", + Content: "describe this image", + Media: []string{ref}, + }); err != nil { + t.Fatalf("Steer failed: %v", err) + } + + resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1") + if err != nil { + t.Fatalf("Continue failed: %v", err) + } + if resp != "ack" { + t.Fatalf("expected ack, got %q", resp) + } + + capMu.Lock() + msgs := append([]providers.Message(nil), capturedMessages...) + capMu.Unlock() + + foundResolvedMedia := false + for _, msg := range msgs { + if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 { + continue + } + if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") { + foundResolvedMedia = true + break + } + } + if !foundResolvedMedia { + t.Fatal("expected continue path to inject steering media into the provider request") + } + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + history := defaultAgent.Sessions.GetHistory(sessionKey) + foundOriginalRef := false + for _, msg := range history { + if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref { + foundOriginalRef = true + break + } + } + if !foundOriginalRef { + t.Fatal("expected original steering media ref to be preserved in session history") + } +} + func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil {