Skip to content

Commit a21f7a9

Browse files
author
xuwenzhe
committed
fix: spawn tool now respects target agent's model config
Fixes #1322 Previously, when spawn tool was called with agent_id parameter, the subagent would use the caller agent's model instead of the target agent's configured model. Changes: - Add modelResolver field to SubagentManager - Add SetModelResolver() method to inject model resolution function - In runTask(), use target agent's model when agent_id is specified - Add test case for model resolver functionality Test Plan: - All existing tests pass - New TestSubagentManager_ModelResolver test passes
1 parent 30584f0 commit a21f7a9

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

pkg/agent/loop.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,13 @@ func registerSharedTools(
224224
if cfg.Tools.IsToolEnabled("subagent") {
225225
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
226226
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
227+
// Set model resolver so spawn can use target agent's model
228+
subagentManager.SetModelResolver(func(targetAgentID string) string {
229+
if targetAgent, ok := registry.GetAgent(targetAgentID); ok {
230+
return targetAgent.Model
231+
}
232+
return ""
233+
})
227234
spawnTool := tools.NewSpawnTool(subagentManager)
228235
currentAgentID := agentID
229236
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {

pkg/tools/spawn_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,38 @@ func TestSpawnTool_Execute_NilManager(t *testing.T) {
7777
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
7878
}
7979
}
80+
81+
func TestSubagentManager_ModelResolver(t *testing.T) {
82+
provider := &MockLLMProvider{}
83+
manager := NewSubagentManager(provider, "default-model", "/tmp/test")
84+
85+
// Set up model resolver
86+
resolvedAgentID := ""
87+
manager.SetModelResolver(func(agentID string) string {
88+
resolvedAgentID = agentID
89+
if agentID == "premium-agent" {
90+
return "gpt-4"
91+
}
92+
return ""
93+
})
94+
95+
// Verify resolver is set
96+
if manager.modelResolver == nil {
97+
t.Fatal("Model resolver should be set")
98+
}
99+
100+
// Test resolver is called with correct agent ID
101+
result := manager.modelResolver("premium-agent")
102+
if resolvedAgentID != "premium-agent" {
103+
t.Errorf("Expected resolver to be called with 'premium-agent', got '%s'", resolvedAgentID)
104+
}
105+
if result != "gpt-4" {
106+
t.Errorf("Expected 'gpt-4', got '%s'", result)
107+
}
108+
109+
// Test fallback for unknown agent
110+
result = manager.modelResolver("unknown-agent")
111+
if result != "" {
112+
t.Errorf("Expected empty string for unknown agent, got '%s'", result)
113+
}
114+
}

pkg/tools/subagent.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ type SubagentManager struct {
3434
hasMaxTokens bool
3535
hasTemperature bool
3636
nextID int
37+
// modelResolver resolves agentID to model name.
38+
// Returns empty string if agent not found (falls back to defaultModel).
39+
modelResolver func(agentID string) string
3740
}
3841

3942
func NewSubagentManager(
@@ -61,6 +64,16 @@ func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
6164
sm.hasTemperature = true
6265
}
6366

67+
// SetModelResolver sets a function to resolve agentID to model name.
68+
// When spawn is called with agent_id, this resolver is used to get the
69+
// target agent's configured model. If the resolver returns empty string
70+
// or is not set, falls back to the defaultModel.
71+
func (sm *SubagentManager) SetModelResolver(resolver func(agentID string) string) {
72+
sm.mu.Lock()
73+
defer sm.mu.Unlock()
74+
sm.modelResolver = resolver
75+
}
76+
6477
// SetTools sets the tool registry for subagent execution.
6578
// If not set, subagent will have access to the provided tools.
6679
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
@@ -147,8 +160,18 @@ After completing the task, provide a clear summary of what was done.`
147160
temperature := sm.temperature
148161
hasMaxTokens := sm.hasMaxTokens
149162
hasTemperature := sm.hasTemperature
163+
modelResolver := sm.modelResolver
164+
defaultModel := sm.defaultModel
150165
sm.mu.RUnlock()
151166

167+
// Resolve target agent model if agentID is specified
168+
model := defaultModel
169+
if task.AgentID != "" && modelResolver != nil {
170+
if resolvedModel := modelResolver(task.AgentID); resolvedModel != "" {
171+
model = resolvedModel
172+
}
173+
}
174+
152175
var llmOptions map[string]any
153176
if hasMaxTokens || hasTemperature {
154177
llmOptions = map[string]any{}
@@ -162,7 +185,7 @@ After completing the task, provide a clear summary of what was done.`
162185

163186
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
164187
Provider: sm.provider,
165-
Model: sm.defaultModel,
188+
Model: model,
166189
Tools: tools,
167190
MaxIterations: maxIter,
168191
LLMOptions: llmOptions,

0 commit comments

Comments
 (0)