diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 599ea57fc6..97cf0fa059 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -37,6 +37,14 @@ type AgentInstance struct { Subagents *config.SubagentsConfig SkillsFilter []string Candidates []providers.FallbackCandidate + + // Router is non-nil when model routing is configured and the light model + // was successfully resolved. It scores each incoming message and decides + // whether to route to LightCandidates or stay with Candidates. + Router *routing.Router + // LightCandidates holds the resolved provider candidates for the light model. + // Pre-computed at agent creation to avoid repeated model_list lookups at runtime. + LightCandidates []providers.FallbackCandidate } // NewAgentInstance creates an agent instance from config. @@ -180,6 +188,25 @@ func NewAgentInstance( candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) + // Model routing setup: pre-resolve light model candidates at creation time + // to avoid repeated model_list lookups on every incoming message. 
+ var router *routing.Router + var lightCandidates []providers.FallbackCandidate + if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" { + lightModelCfg := providers.ModelConfig{Primary: rc.LightModel} + resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList) + if len(resolved) > 0 { + router = routing.New(routing.RouterConfig{ + LightModel: rc.LightModel, + Threshold: rc.Threshold, + }) + lightCandidates = resolved + } else { + log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q", + rc.LightModel, agentID) + } + } + return &AgentInstance{ ID: agentID, Name: agentName, @@ -200,6 +227,8 @@ func NewAgentInstance( Subagents: subagents, SkillsFilter: skillsFilter, Candidates: candidates, + Router: router, + LightCandidates: lightCandidates, } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 685b346e69..132bb3c981 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -824,6 +824,12 @@ func (al *AgentLoop) runLLMIteration( iteration := 0 var finalContent string + // Determine effective model tier for this conversation turn. + // selectCandidates evaluates routing once and the decision is sticky for + // all tool-follow-up iterations within the same turn so that a multi-step + // tool chain doesn't switch models mid-way through. + activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) + for iteration < agent.MaxIterations { iteration++ @@ -842,7 +848,7 @@ func (al *AgentLoop) runLLMIteration( map[string]any{ "agent_id": agent.ID, "iteration": iteration, - "model": agent.Model, + "model": activeModel, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": agent.MaxTokens, @@ -858,7 +864,7 @@ func (al *AgentLoop) runLLMIteration( "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM with fallback chain if candidates are configured. 
+ // Call LLM with fallback chain if multiple candidates are configured. var response *providers.LLMResponse var err error @@ -879,10 +885,10 @@ func (al *AgentLoop) runLLMIteration( } callLLM := func() (*providers.LLMResponse, error) { - if len(agent.Candidates) > 1 && al.fallback != nil { + if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, - agent.Candidates, + activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) }, @@ -900,7 +906,7 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts) + return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) } // Retry loop for context/token errors @@ -1169,6 +1175,44 @@ func (al *AgentLoop) runLLMIteration( return finalContent, iteration, nil } +// selectCandidates returns the model candidates and resolved model name to use +// for a conversation turn. When model routing is configured and the incoming +// message scores below the complexity threshold, it returns the light model +// candidates instead of the primary ones. +// +// The returned (candidates, model) pair is used for all LLM calls within one +// turn — tool follow-up iterations use the same tier as the initial call so +// that a multi-step tool chain doesn't switch models mid-way. 
+func (al *AgentLoop) selectCandidates( + agent *AgentInstance, + userMsg string, + history []providers.Message, +) (candidates []providers.FallbackCandidate, model string) { + if agent.Router == nil || len(agent.LightCandidates) == 0 { + return agent.Candidates, agent.Model + } + + _, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model) + if !usedLight { + logger.DebugCF("agent", "Model routing: primary model selected", + map[string]any{ + "agent_id": agent.ID, + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.Candidates, agent.Model + } + + logger.InfoCF("agent", "Model routing: light model selected", + map[string]any{ + "agent_id": agent.ID, + "light_model": agent.Router.LightModel(), + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.LightCandidates, agent.Router.LightModel() +} + // maybeSummarize triggers summarization if the session history exceeds thresholds. func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) diff --git a/pkg/config/config.go b/pkg/config/config.go index 7a0ec323c2..23dca8cb8d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -167,22 +167,35 @@ type SessionConfig struct { IdentityLinks map[string][]string `json:"identity_links,omitempty"` } +// RoutingConfig controls the intelligent model routing feature. +// When enabled, each incoming message is scored against structural features +// (message length, code blocks, tool call history, conversation depth, attachments). +// Messages scoring below Threshold are sent to LightModel; all others use the +// agent's primary model. This reduces cost and latency for simple tasks without +// requiring any keyword matching — all scoring is language-agnostic. 
+type RoutingConfig struct { + Enabled bool `json:"enabled"` + LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks + Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model +} + type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead - ModelFallbacks []string `json:"model_fallbacks,omitempty"` - ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` - ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` - SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` - SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` - MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" 
env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Routing *RoutingConfig `json:"routing,omitempty"` } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB diff --git a/pkg/routing/classifier.go b/pkg/routing/classifier.go new file mode 100644 index 0000000000..8cddaf0690 --- /dev/null +++ b/pkg/routing/classifier.go @@ -0,0 +1,80 @@ +package routing + +// Classifier evaluates a feature set and returns a complexity score in [0, 1]. +// A higher score indicates a more complex task that benefits from a heavy model. +// The score is compared against the configured threshold: score >= threshold selects +// the primary (heavy) model; score < threshold selects the light model. +// +// Classifier is an interface so that future implementations (ML-based, embedding-based, +// or any other approach) can be swapped in without changing routing infrastructure. 
+type Classifier interface {
+	Score(f Features) float64
+}
+
+// RuleClassifier is the v1 implementation.
+// It uses a weighted sum of structural signals with no external dependencies,
+// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
+// that the returned score always falls within the [0, 1] contract.
+//
+// Individual weights (multiple signals can fire simultaneously):
+//
+//	token > 200 (≈800 chars): 0.35 — very long prompts are almost always complex
+//	token 50-200: 0.15 — medium length; may or may not be complex
+//	code block present: 0.40 — coding tasks need the heavy model
+//	tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
+//	tool calls 1-3 (recent): 0.10 — some tool activity
+//	conversation depth > 10: 0.10 — long sessions carry implicit complexity
+//	attachments present: 1.00 — hard gate; multi-modal always needs heavy model
+//
+// Default threshold is 0.35, so:
+//   - Pure greetings / trivial Q&A: 0.00 → light ✓
+//   - Medium prose message (50–200 tokens): 0.15 → light ✓
+//   - Message with code block: 0.40 → heavy ✓
+//   - Long message (>200 tokens): 0.35 → heavy ✓
+//   - Active tool session + medium message: 0.25 → light (acceptable)
+//   - Any message with an image/audio attachment: 1.00 → heavy ✓
+type RuleClassifier struct{}
+
+// Score computes the complexity score for the given feature set.
+// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
+func (c *RuleClassifier) Score(f Features) float64 {
+	// Hard gate: multi-modal inputs always require the heavy model. 
+ if f.HasAttachments { + return 1.0 + } + + var score float64 + + // Token estimate — primary verbosity signal + switch { + case f.TokenEstimate > 200: + score += 0.35 + case f.TokenEstimate > 50: + score += 0.15 + } + + // Fenced code blocks — strongest indicator of a coding/technical task + if f.CodeBlockCount > 0 { + score += 0.40 + } + + // Recent tool call density — indicates an ongoing agentic workflow + switch { + case f.RecentToolCalls > 3: + score += 0.25 + case f.RecentToolCalls > 0: + score += 0.10 + } + + // Conversation depth — accumulated context implies compound task + if f.ConversationDepth > 10 { + score += 0.10 + } + + // Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire + // simultaneously (e.g., long message + code block + tool chain = 1.10 raw). + if score > 1.0 { + score = 1.0 + } + return score +} diff --git a/pkg/routing/features.go b/pkg/routing/features.go new file mode 100644 index 0000000000..c371e21aa3 --- /dev/null +++ b/pkg/routing/features.go @@ -0,0 +1,127 @@ +package routing + +import ( + "strings" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// lookbackWindow is the number of recent history entries scanned for tool calls. +// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant). +const lookbackWindow = 6 + +// Features holds the structural signals extracted from a message and its session context. +// Every dimension is language-agnostic by construction — no keyword or pattern matching +// against natural-language content. This ensures consistent routing for all locales. +type Features struct { + // TokenEstimate is a proxy for token count. + // CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each. + // This avoids API calls while giving accurate estimates for all scripts. + TokenEstimate int + + // CodeBlockCount is the number of fenced code blocks (``` pairs) in the message. 
+ // Coding tasks almost always require the heavy model. + CodeBlockCount int + + // RecentToolCalls is the count of tool_call messages in the last lookbackWindow + // history entries. A high density indicates an active agentic workflow. + RecentToolCalls int + + // ConversationDepth is the total number of messages in the session history. + // Deep sessions tend to carry implicit complexity built up over many turns. + ConversationDepth int + + // HasAttachments is true when the message appears to contain media (images, + // audio, video). Multi-modal inputs require vision-capable heavy models. + HasAttachments bool +} + +// ExtractFeatures computes the structural feature vector for a message. +// It is a pure function with no side effects and zero allocations beyond +// the returned struct. +func ExtractFeatures(msg string, history []providers.Message) Features { + return Features{ + TokenEstimate: estimateTokens(msg), + CodeBlockCount: countCodeBlocks(msg), + RecentToolCalls: countRecentToolCalls(history), + ConversationDepth: len(history), + HasAttachments: hasAttachments(msg), + } +} + +// estimateTokens returns a token count proxy that handles both CJK and Latin text. +// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one +// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token +// for English). Splitting the count this way avoids the 3x underestimation that a +// flat rune_count/3 would produce for Chinese, Japanese, and Korean text. +func estimateTokens(msg string) int { + total := utf8.RuneCountInString(msg) + if total == 0 { + return 0 + } + cjk := 0 + for _, r := range msg { + if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF { + cjk++ + } + } + return cjk + (total-cjk)/4 +} + +// countCodeBlocks counts the number of complete fenced code blocks. +// Each ``` delimiter increments a counter; pairs of delimiters form one block. 
+// An unclosed opening fence (odd count) is treated as zero complete blocks +// since it may just be an inline code span or a typo. +func countCodeBlocks(msg string) int { + n := strings.Count(msg, "```") + return n / 2 +} + +// countRecentToolCalls counts messages with tool calls in the last lookbackWindow +// entries of history. It examines the ToolCalls field rather than parsing +// the content string, so it is robust to any message format. +func countRecentToolCalls(history []providers.Message) int { + start := len(history) - lookbackWindow + if start < 0 { + start = 0 + } + + count := 0 + for _, msg := range history[start:] { + if len(msg.ToolCalls) > 0 { + count += len(msg.ToolCalls) + } + } + return count +} + +// hasAttachments returns true when the message content contains embedded media. +// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and +// common image/audio URL extensions. This is intentionally conservative — +// false negatives (missing an attachment) just mean the routing falls back to +// the primary model anyway. 
+func hasAttachments(msg string) bool { + lower := strings.ToLower(msg) + + // Base64 data URIs embedded directly in the message + if strings.Contains(lower, "data:image/") || + strings.Contains(lower, "data:audio/") || + strings.Contains(lower, "data:video/") { + return true + } + + // Common image/audio extensions in URLs or file references + mediaExts := []string{ + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", + ".mp3", ".wav", ".ogg", ".m4a", ".flac", + ".mp4", ".avi", ".mov", ".webm", + } + for _, ext := range mediaExts { + if strings.Contains(lower, ext) { + return true + } + } + + return false +} diff --git a/pkg/routing/router.go b/pkg/routing/router.go new file mode 100644 index 0000000000..b1fa347e95 --- /dev/null +++ b/pkg/routing/router.go @@ -0,0 +1,82 @@ +package routing + +import ( + "github.com/sipeed/picoclaw/pkg/providers" +) + +// defaultThreshold is used when the config threshold is zero or negative. +// At 0.35 a message needs at least one strong signal (code block, long text, +// or an attachment) before the heavy model is chosen. +const defaultThreshold = 0.35 + +// RouterConfig holds the validated model routing settings. +// It mirrors config.RoutingConfig but lives in pkg/routing to keep the +// dependency graph simple: pkg/agent resolves config → routing, not the reverse. +type RouterConfig struct { + // LightModel is the model_name (from model_list) used for simple tasks. + LightModel string + + // Threshold is the complexity score cutoff in [0, 1]. + // score >= Threshold → primary (heavy) model. + // score < Threshold → light model. + Threshold float64 +} + +// Router selects the appropriate model tier for each incoming message. +// It is safe for concurrent use from multiple goroutines. +type Router struct { + cfg RouterConfig + classifier Classifier +} + +// New creates a Router with the given config and the default RuleClassifier. +// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used. 
+func New(cfg RouterConfig) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{ + cfg: cfg, + classifier: &RuleClassifier{}, + } +} + +// newWithClassifier creates a Router with a custom Classifier. +// Intended for unit tests that need to inject a deterministic scorer. +func newWithClassifier(cfg RouterConfig, c Classifier) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{cfg: cfg, classifier: c} +} + +// SelectModel returns the model to use for this conversation turn along with +// the computed complexity score (for logging and debugging). +// +// - If score < cfg.Threshold: returns (cfg.LightModel, true, score) +// - Otherwise: returns (primaryModel, false, score) +// +// The caller is responsible for resolving the returned model name into +// provider candidates (see AgentInstance.LightCandidates). +func (r *Router) SelectModel( + msg string, + history []providers.Message, + primaryModel string, +) (model string, usedLight bool, score float64) { + features := ExtractFeatures(msg, history) + score = r.classifier.Score(features) + if score < r.cfg.Threshold { + return r.cfg.LightModel, true, score + } + return primaryModel, false, score +} + +// LightModel returns the configured light model name. +func (r *Router) LightModel() string { + return r.cfg.LightModel +} + +// Threshold returns the complexity threshold in use. 
+func (r *Router) Threshold() float64 { + return r.cfg.Threshold +} diff --git a/pkg/routing/router_test.go b/pkg/routing/router_test.go new file mode 100644 index 0000000000..2824d10abf --- /dev/null +++ b/pkg/routing/router_test.go @@ -0,0 +1,414 @@ +package routing + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// ── ExtractFeatures ────────────────────────────────────────────────────────── + +func TestExtractFeatures_EmptyMessage(t *testing.T) { + f := ExtractFeatures("", nil) + if f.TokenEstimate != 0 { + t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate) + } + if f.CodeBlockCount != 0 { + t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount) + } + if f.RecentToolCalls != 0 { + t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls) + } + if f.ConversationDepth != 0 { + t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth) + } + if f.HasAttachments { + t.Error("HasAttachments: got true, want false") + } +} + +func TestExtractFeatures_TokenEstimate(t *testing.T) { + // 30 ASCII runes: 0 CJK + 30/4 = 7 tokens + msg := strings.Repeat("a", 30) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 7 { + t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) { + // 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token). + // Using a rune slice literal avoids CJK string literals in source. + msg := string([]rune{ + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, + }) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 9 { + t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) { + // Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens. 
+ msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok" + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 6 { + t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate) + } +} + +func TestExtractFeatures_CodeBlocks(t *testing.T) { + cases := []struct { + msg string + want int + }{ + {"no code here", 0}, + {"```go\nfmt.Println()\n```", 1}, + {"```python\npass\n```\n```js\nconsole.log()\n```", 2}, + {"```unclosed", 0}, // odd number of fences = 0 complete blocks + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.CodeBlockCount != tc.want { + t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want) + } + } +} + +func TestExtractFeatures_RecentToolCalls(t *testing.T) { + // History longer than lookbackWindow — only last lookbackWindow entries count. + history := make([]providers.Message, 10) + // Put 2 tool calls at positions 8 and 9 (within the last 6) + history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}} + history[9] = providers.Message{ + Role: "assistant", + ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}, + } + // Position 3 is outside the lookback window and must NOT be counted + history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}} + + f := ExtractFeatures("test", history) + // 1 (position 8) + 2 (position 9) = 3 + if f.RecentToolCalls != 3 { + t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls) + } +} + +func TestExtractFeatures_ConversationDepth(t *testing.T) { + history := make([]providers.Message, 7) + f := ExtractFeatures("msg", history) + if f.ConversationDepth != 7 { + t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth) + } +} + +func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"plain text", false}, + {"here is an image: data:image/png;base64,abc123", true}, + 
{"audio: data:audio/mp3;base64,xyz", true}, + {"video: data:video/mp4;base64,xyz", true}, + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +func TestExtractFeatures_HasAttachments_Extension(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"check out photo.jpg", true}, + {"see screenshot.png", true}, + {"listen to audio.mp3", true}, + {"watch clip.mp4", true}, + {"just a .go file", false}, + {"document.pdf", false}, // pdf is not in the media list + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +// ── RuleClassifier ─────────────────────────────────────────────────────────── + +func TestRuleClassifier_ZeroFeatures(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{}) + if score != 0.0 { + t.Errorf("zero features: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_AttachmentsHardGate(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{HasAttachments: true}) + if score != 1.0 { + t.Errorf("attachments: got %f, want 1.0", score) + } +} + +func TestRuleClassifier_CodeBlockAlone(t *testing.T) { + c := &RuleClassifier{} + // Code block alone = 0.40, above default threshold 0.35 + score := c.Score(Features{CodeBlockCount: 1}) + if score < 0.35 { + t.Errorf("code block: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_LongMessage(t *testing.T) { + c := &RuleClassifier{} + // >200 tokens = 0.35, exactly at default threshold → heavy + score := c.Score(Features{TokenEstimate: 250}) + if score < 0.35 { + t.Errorf("long message: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_MediumMessage(t *testing.T) { + c := &RuleClassifier{} + // 
50-200 tokens = 0.15, below threshold → light + score := c.Score(Features{TokenEstimate: 100}) + if score >= 0.35 { + t.Errorf("medium message: score %f should be below default threshold 0.35", score) + } +} + +func TestRuleClassifier_ShortMessage(t *testing.T) { + c := &RuleClassifier{} + // <50 tokens, no other signals = 0.0 → light + score := c.Score(Features{TokenEstimate: 10}) + if score != 0.0 { + t.Errorf("short message: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_ToolCallDensity(t *testing.T) { + c := &RuleClassifier{} + + scoreNone := c.Score(Features{RecentToolCalls: 0}) + scoreLow := c.Score(Features{RecentToolCalls: 2}) + scoreHigh := c.Score(Features{RecentToolCalls: 5}) + + if scoreNone != 0.0 { + t.Errorf("no tools: got %f, want 0.0", scoreNone) + } + if scoreLow <= scoreNone { + t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone) + } + if scoreHigh <= scoreLow { + t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow) + } +} + +func TestRuleClassifier_DeepConversation(t *testing.T) { + c := &RuleClassifier{} + shallow := c.Score(Features{ConversationDepth: 5}) + deep := c.Score(Features{ConversationDepth: 15}) + if deep <= shallow { + t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow) + } +} + +func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) { + c := &RuleClassifier{} + // Max all signals simultaneously + f := Features{ + TokenEstimate: 500, + CodeBlockCount: 3, + RecentToolCalls: 10, + ConversationDepth: 20, + } + score := c.Score(f) + if score > 1.0 { + t.Errorf("score %f exceeds 1.0", score) + } +} + +// ── Router ─────────────────────────────────────────────────────────────────── + +func TestRouter_DefaultThreshold(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash"}) + if r.Threshold() != defaultThreshold { + t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func 
TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1}) + if r.Threshold() != defaultThreshold { + t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "hi" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if !usedLight { + t.Error("simple message: expected light model to be selected") + } + if model != "gemini-flash" { + t.Errorf("simple message: model got %q, want %q", model, "gemini-flash") + } +} + +func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "```go\nfmt.Println(\"hello\")\n```" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("code block: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "can you analyze this? 
data:image/png;base64,abc123"
+	model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
+	if usedLight {
+		t.Error("attachment: expected primary model to be selected")
+	}
+	if model != "claude-sonnet-4-6" {
+		t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
+	}
+}
+
+func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
+	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
+	// >200 token estimate: 210 repeats of "word " = 1050 chars → ~262 tokens
+	msg := strings.Repeat("word ", 210)
+	model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
+	if usedLight {
+		t.Error("long message: expected primary model to be selected")
+	}
+	if model != "claude-sonnet-4-6" {
+		t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
+	}
+}
+
+func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
+	// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
+	// Routing is conservative: only promote to heavy when the signal is unambiguous. 
+	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
+	history := []providers.Message{
+		{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
+		{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
+	}
+	msg := "ok"
+	_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
+	if !usedLight {
+		t.Error("short message + moderate tool calls: expected light model (score 0.25 < 0.35)")
+	}
+}
+
+func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
+	// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
+	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
+	history := []providers.Message{
+		{Role: "assistant", ToolCalls: []providers.ToolCall{
+			{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
+		}},
+	}
+	// 55 repeats of "word " = 275 chars → ~68 tokens (medium band, 0.15)
+	msg := strings.Repeat("word ", 55)
+	_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
+	if usedLight {
+		t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
+	}
+}
+
+func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
+	// Very low threshold: even a short message triggers heavy model
+	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
+	msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
+	_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
+	if usedLight {
+		t.Error("low threshold: medium message should use primary model")
+	}
+}
+
+func TestRouter_SelectModel_HighThreshold(t *testing.T) {
+	// Very high threshold: even code blocks route to light
+	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
+	msg := "```go\nfmt.Println()\n```"
+	_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
+	if !usedLight {
+		t.Error("very high threshold: code block (0.40) should route to light model")
+	}
+}
+
+func TestRouter_LightModel(t *testing.T) {
+	r := 
New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35}) + if r.LightModel() != "my-fast-model" { + t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model") + } +} + +// ── newWithClassifier (internal testing hook) ───────────────────────────────── + +type fixedScoreClassifier struct{ score float64 } + +func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score } + +func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.2}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if !usedLight { + t.Error("low score with custom classifier: expected light model") + } +} + +func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.8}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("high score with custom classifier: expected primary model") + } +} + +func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) { + // score == threshold → primary (uses >= comparison) + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.5}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("score == threshold: expected primary model (>= threshold → primary)") + } +} + +func TestRouter_SelectModel_ReturnsScore(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.42}, + ) + _, _, score := r.SelectModel("anything", nil, "heavy") + if score != 0.42 { + t.Errorf("score: got %f, want 0.42", score) + } +}