diff --git a/go.mod b/go.mod index 9f755bbc91..1c699a7243 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 @@ -16,6 +17,7 @@ require ( github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 + github.com/rivo/tview v0.42.0 github.com/slack-go/slack v0.17.3 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 @@ -35,7 +37,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/gdamore/tcell/v2 v2.13.8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -44,7 +45,6 @@ require ( github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/tview v0.42.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 88afa61194..ac9b449a21 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -178,6 +178,17 @@ func (al *AgentLoop) Run(ctx context.Context) error { // Initialize MCP servers for all agents if al.cfg.Tools.MCP.Enabled { mcpManager := mcp.NewManager() + // Ensure MCP connections are cleaned up on exit, regardless of initialization success + // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails + defer func() { + if err := mcpManager.Close(); err != nil { 
+ logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": err.Error(), + }) + } + }() + defaultAgent := al.registry.GetDefaultAgent() var workspacePath string if defaultAgent != nil && defaultAgent.Workspace != "" { @@ -192,16 +203,6 @@ func (al *AgentLoop) Run(ctx context.Context) error { "error": err.Error(), }) } else { - // Ensure MCP connections are cleaned up on exit, only if initialization succeeded - defer func() { - if err := mcpManager.Close(); err != nil { - logger.ErrorCF("agent", "Failed to close MCP manager", - map[string]any{ - "error": err.Error(), - }) - } - }() - // Register MCP tools for all agents servers := mcpManager.GetServers() uniqueTools := 0 diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go index 8b6d6d9aa9..7b63cc979f 100644 --- a/pkg/mcp/manager.go +++ b/pkg/mcp/manager.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -108,7 +109,7 @@ type ServerConnection struct { type Manager struct { servers map[string]*ServerConnection mu sync.RWMutex - closed bool + closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race wg sync.WaitGroup // tracks in-flight CallTool calls } @@ -440,14 +441,20 @@ func (m *Manager) CallTool( serverName, toolName string, arguments map[string]any, ) (*mcp.CallToolResult, error) { + // Check if closed before acquiring lock (fast path) + if m.closed.Load() { + return nil, fmt.Errorf("manager is closed") + } + m.mu.RLock() - if m.closed { + // Double-check after acquiring lock to prevent TOCTOU race + if m.closed.Load() { m.mu.RUnlock() return nil, fmt.Errorf("manager is closed") } conn, ok := m.servers[serverName] if ok { - m.wg.Add(1) + m.wg.Add(1) // Add to WaitGroup while holding the lock } m.mu.RUnlock() @@ -471,15 +478,14 @@ func (m *Manager) CallTool( // Close closes all server connections func (m *Manager) Close() error { - m.mu.Lock() - if m.closed { - m.mu.Unlock() - return nil + // 
Use Swap to atomically set closed=true and get the previous value + // NOTE(review): Swap runs without m.mu, so a CallTool that already passed its closed check is not excluded + if m.closed.Swap(true) { + return nil // already closed } - m.closed = true - m.mu.Unlock() // Wait for all in-flight CallTool calls to finish before closing sessions + // NOTE(review): a CallTool past its closed check may still call wg.Add(1) after this Wait returns; confirm m.wg.Wait() m.mu.Lock() diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index 6dd71a3c2f..8ce81d09e8 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -268,7 +268,7 @@ func TestGetAllTools_FiltersEmptyTools(t *testing.T) { func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { t.Run("manager closed", func(t *testing.T) { mgr := NewManager() - mgr.closed = true + mgr.closed.Store(true) _, err := mgr.CallTool(context.Background(), "s1", "tool", nil) if err == nil || !strings.Contains(err.Error(), "manager is closed") {