Skip to content

Commit 9b1e73d

Browse files
authored
Merge pull request #994 from is-Xiaoen/feat/model-routing
feat(routing): intelligent model routing based on structural complexity scoring
2 parents 4d965f2 + b84adac commit 9b1e73d

7 files changed

Lines changed: 809 additions & 20 deletions

File tree

pkg/agent/instance.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ type AgentInstance struct {
3737
Subagents *config.SubagentsConfig
3838
SkillsFilter []string
3939
Candidates []providers.FallbackCandidate
40+
41+
// Router is non-nil when model routing is configured and the light model
42+
// was successfully resolved. It scores each incoming message and decides
43+
// whether to route to LightCandidates or stay with Candidates.
44+
Router *routing.Router
45+
// LightCandidates holds the resolved provider candidates for the light model.
46+
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
47+
LightCandidates []providers.FallbackCandidate
4048
}
4149

4250
// NewAgentInstance creates an agent instance from config.
@@ -180,6 +188,25 @@ func NewAgentInstance(
180188

181189
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
182190

191+
// Model routing setup: pre-resolve light model candidates at creation time
192+
// to avoid repeated model_list lookups on every incoming message.
193+
var router *routing.Router
194+
var lightCandidates []providers.FallbackCandidate
195+
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
196+
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
197+
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
198+
if len(resolved) > 0 {
199+
router = routing.New(routing.RouterConfig{
200+
LightModel: rc.LightModel,
201+
Threshold: rc.Threshold,
202+
})
203+
lightCandidates = resolved
204+
} else {
205+
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
206+
rc.LightModel, agentID)
207+
}
208+
}
209+
183210
return &AgentInstance{
184211
ID: agentID,
185212
Name: agentName,
@@ -200,6 +227,8 @@ func NewAgentInstance(
200227
Subagents: subagents,
201228
SkillsFilter: skillsFilter,
202229
Candidates: candidates,
230+
Router: router,
231+
LightCandidates: lightCandidates,
203232
}
204233
}
205234

pkg/agent/loop.go

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,12 @@ func (al *AgentLoop) runLLMIteration(
824824
iteration := 0
825825
var finalContent string
826826

827+
// Determine effective model tier for this conversation turn.
828+
// selectCandidates evaluates routing once and the decision is sticky for
829+
// all tool-follow-up iterations within the same turn so that a multi-step
830+
// tool chain doesn't switch models mid-way through.
831+
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
832+
827833
for iteration < agent.MaxIterations {
828834
iteration++
829835

@@ -842,7 +848,7 @@ func (al *AgentLoop) runLLMIteration(
842848
map[string]any{
843849
"agent_id": agent.ID,
844850
"iteration": iteration,
845-
"model": agent.Model,
851+
"model": activeModel,
846852
"messages_count": len(messages),
847853
"tools_count": len(providerToolDefs),
848854
"max_tokens": agent.MaxTokens,
@@ -858,7 +864,7 @@ func (al *AgentLoop) runLLMIteration(
858864
"tools_json": formatToolsForLog(providerToolDefs),
859865
})
860866

861-
// Call LLM with fallback chain if candidates are configured.
867+
// Call LLM with fallback chain if multiple candidates are configured.
862868
var response *providers.LLMResponse
863869
var err error
864870

@@ -879,10 +885,10 @@ func (al *AgentLoop) runLLMIteration(
879885
}
880886

881887
callLLM := func() (*providers.LLMResponse, error) {
882-
if len(agent.Candidates) > 1 && al.fallback != nil {
888+
if len(activeCandidates) > 1 && al.fallback != nil {
883889
fbResult, fbErr := al.fallback.Execute(
884890
ctx,
885-
agent.Candidates,
891+
activeCandidates,
886892
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
887893
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
888894
},
@@ -900,7 +906,7 @@ func (al *AgentLoop) runLLMIteration(
900906
}
901907
return fbResult.Response, nil
902908
}
903-
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
909+
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
904910
}
905911

906912
// Retry loop for context/token errors
@@ -1169,6 +1175,44 @@ func (al *AgentLoop) runLLMIteration(
11691175
return finalContent, iteration, nil
11701176
}
11711177

1178+
// selectCandidates returns the model candidates and resolved model name to use
1179+
// for a conversation turn. When model routing is configured and the incoming
1180+
// message scores below the complexity threshold, it returns the light model
1181+
// candidates instead of the primary ones.
1182+
//
1183+
// The returned (candidates, model) pair is used for all LLM calls within one
1184+
// turn — tool follow-up iterations use the same tier as the initial call so
1185+
// that a multi-step tool chain doesn't switch models mid-way.
1186+
func (al *AgentLoop) selectCandidates(
1187+
agent *AgentInstance,
1188+
userMsg string,
1189+
history []providers.Message,
1190+
) (candidates []providers.FallbackCandidate, model string) {
1191+
if agent.Router == nil || len(agent.LightCandidates) == 0 {
1192+
return agent.Candidates, agent.Model
1193+
}
1194+
1195+
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
1196+
if !usedLight {
1197+
logger.DebugCF("agent", "Model routing: primary model selected",
1198+
map[string]any{
1199+
"agent_id": agent.ID,
1200+
"score": score,
1201+
"threshold": agent.Router.Threshold(),
1202+
})
1203+
return agent.Candidates, agent.Model
1204+
}
1205+
1206+
logger.InfoCF("agent", "Model routing: light model selected",
1207+
map[string]any{
1208+
"agent_id": agent.ID,
1209+
"light_model": agent.Router.LightModel(),
1210+
"score": score,
1211+
"threshold": agent.Router.Threshold(),
1212+
})
1213+
return agent.LightCandidates, agent.Router.LightModel()
1214+
}
1215+
11721216
// maybeSummarize triggers summarization if the session history exceeds thresholds.
11731217
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
11741218
newHistory := agent.Sessions.GetHistory(sessionKey)

pkg/config/config.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,22 +167,35 @@ type SessionConfig struct {
167167
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
168168
}
169169

170+
// RoutingConfig controls the intelligent model routing feature.
//
// When Enabled, every incoming message is scored on structural features alone
// (message length, code blocks, recent tool-call history, conversation depth,
// attachments) — no keyword matching, so scoring is language-agnostic.
// A message whose score falls below Threshold is handled by LightModel; any
// message at or above Threshold goes to the agent's primary model. The intent
// is lower cost and latency on simple tasks.
type RoutingConfig struct {
	Enabled    bool    `json:"enabled"`
	LightModel string  `json:"light_model"` // model_name from model_list to use for simple tasks
	Threshold  float64 `json:"threshold"`   // complexity score in [0,1]; score >= threshold → primary model
}
181+
170182
type AgentDefaults struct {
171-
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
172-
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
173-
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
174-
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
175-
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
176-
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
177-
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
178-
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
179-
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
180-
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
181-
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
182-
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
183-
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
184-
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
185-
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
183+
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
184+
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
185+
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
186+
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
187+
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
188+
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
189+
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
190+
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
191+
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
192+
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
193+
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
194+
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
195+
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
196+
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
197+
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
198+
Routing *RoutingConfig `json:"routing,omitempty"`
186199
}
187200

188201
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB

pkg/routing/classifier.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package routing
2+
3+
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
4+
// A higher score indicates a more complex task that benefits from a heavy model.
5+
// The score is compared against the configured threshold: score >= threshold selects
6+
// the primary (heavy) model; score < threshold selects the light model.
7+
//
8+
// Classifier is an interface so that future implementations (ML-based, embedding-based,
9+
// or any other approach) can be swapped in without changing routing infrastructure.
10+
type Classifier interface {
11+
Score(f Features) float64
12+
}
13+
14+
// RuleClassifier is the v1 implementation.
15+
// It uses a weighted sum of structural signals with no external dependencies,
16+
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
17+
// that the returned score always falls within the [0, 1] contract.
18+
//
19+
// Individual weights (multiple signals can fire simultaneously):
20+
//
21+
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
22+
// token 50-200: 0.15 — medium length; may or may not be complex
23+
// code block present: 0.40 — coding tasks need the heavy model
24+
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
25+
// tool calls 1-3 (recent): 0.10 — some tool activity
26+
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
27+
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
28+
//
29+
// Default threshold is 0.35, so:
30+
// - Pure greetings / trivial Q&A: 0.00 → light ✓
31+
// - Medium prose message (50–200 tokens): 0.15 → light ✓
32+
// - Message with code block: 0.40 → heavy ✓
33+
// - Long message (>200 tokens): 0.35 → heavy ✓
34+
// - Active tool session + medium message: 0.25 → light (acceptable)
35+
// - Any message with an image/audio attachment: 1.00 → heavy ✓
36+
type RuleClassifier struct{}
37+
38+
// Score computes the complexity score for the given feature set.
39+
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
40+
func (c *RuleClassifier) Score(f Features) float64 {
41+
// Hard gate: multi-modal inputs always require the heavy model.
42+
if f.HasAttachments {
43+
return 1.0
44+
}
45+
46+
var score float64
47+
48+
// Token estimate — primary verbosity signal
49+
switch {
50+
case f.TokenEstimate > 200:
51+
score += 0.35
52+
case f.TokenEstimate > 50:
53+
score += 0.15
54+
}
55+
56+
// Fenced code blocks — strongest indicator of a coding/technical task
57+
if f.CodeBlockCount > 0 {
58+
score += 0.40
59+
}
60+
61+
// Recent tool call density — indicates an ongoing agentic workflow
62+
switch {
63+
case f.RecentToolCalls > 3:
64+
score += 0.25
65+
case f.RecentToolCalls > 0:
66+
score += 0.10
67+
}
68+
69+
// Conversation depth — accumulated context implies compound task
70+
if f.ConversationDepth > 10 {
71+
score += 0.10
72+
}
73+
74+
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
75+
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
76+
if score > 1.0 {
77+
score = 1.0
78+
}
79+
return score
80+
}

0 commit comments

Comments
 (0)