diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 6593c67f8a9..57c10e7a094 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -98,6 +98,7 @@ vi.mock('../tools/mcp-client-manager.js', () => ({ McpClientManager: vi.fn().mockImplementation(() => ({ startConfiguredMcpServers: vi.fn(), getMcpInstructions: vi.fn().mockReturnValue('MCP Instructions'), + setMainRegistries: vi.fn(), })), })); @@ -368,6 +369,7 @@ describe('Server Config (config.ts)', () => { mcpStarted = true; }), getMcpInstructions: vi.fn(), + setMainRegistries: vi.fn(), }) as Partial as McpClientManager, ); @@ -401,6 +403,7 @@ describe('Server Config (config.ts)', () => { mcpStarted = true; }), getMcpInstructions: vi.fn(), + setMainRegistries: vi.fn(), }) as Partial as McpClientManager, ); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 31c2128f316..cd795002680 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -239,6 +239,8 @@ export interface AgentOverride { modelConfig?: ModelConfig; runConfig?: AgentRunConfig; enabled?: boolean; + tools?: string[]; + mcpServers?: Record; } export interface AgentSettings { @@ -520,6 +522,7 @@ export interface ConfigParameters { question?: string; coreTools?: string[]; + mainAgentTools?: string[]; /** @deprecated Use Policy Engine instead */ allowedTools?: string[]; /** @deprecated Use Policy Engine instead */ @@ -675,6 +678,7 @@ export class Config implements McpContext, AgentLoopContext { readonly enableConseca: boolean; private readonly coreTools: string[] | undefined; + private readonly mainAgentTools: string[] | undefined; /** @deprecated Use Policy Engine instead */ private readonly allowedTools: string[] | undefined; /** @deprecated Use Policy Engine instead */ @@ -888,6 +892,7 @@ export class Config implements McpContext, AgentLoopContext { this.question = params.question; this.coreTools = params.coreTools; + this.mainAgentTools = params.mainAgentTools; this.allowedTools = params.allowedTools; this.excludeTools = params.excludeTools; this.toolDiscoveryCommand = params.toolDiscoveryCommand; @@ -1231,10 +1236,14 @@ export class Config implements McpContext, AgentLoopContext { discoverToolsHandle?.end(); this.mcpClientManager = new McpClientManager( this.clientVersion, - this._toolRegistry, this, this.eventEmitter, ); + this.mcpClientManager.setMainRegistries({ + toolRegistry: this._toolRegistry, + promptRegistry: this.promptRegistry, + resourceRegistry: this.resourceRegistry, + }); // We do not await this promise so that the CLI can start up even if // MCP servers are slow to connect. this.mcpInitializationPromise = Promise.allSettled([ @@ -1887,6 +1896,10 @@ export class Config implements McpContext, AgentLoopContext { return this.coreTools; } + getMainAgentTools(): string[] | undefined { + return this.mainAgentTools; + } + getAllowedTools(): string[] | undefined { return this.allowedTools; } @@ -2982,7 +2995,11 @@ export class Config implements McpContext, AgentLoopContext { } async createToolRegistry(): Promise { - const registry = new ToolRegistry(this, this.messageBus); + const registry = new ToolRegistry( + this, + this.messageBus, + /* isMainRegistry= */ true, + ); // helper to create & register core tools that are enabled const maybeRegister = ( diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index c35ae2e0841..dce8708628b 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -14,9 +14,11 @@ import { type MockedObject, } from 'vitest'; import { McpClientManager } from './mcp-client-manager.js'; -import { McpClient, MCPDiscoveryState } from './mcp-client.js'; +import { McpClient, MCPDiscoveryState, MCPServerStatus } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import type { Config, GeminiCLIExtension } from '../config/config.js'; +import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import type { ResourceRegistry } from '../resources/resource-registry.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); @@ -34,21 +36,25 @@ describe('McpClientManager', () => { beforeEach(() => { mockedMcpClient = vi.mockObject({ connect: vi.fn(), - discover: vi.fn(), + discoverInto: vi.fn(), disconnect: vi.fn(), - getStatus: vi.fn(), + getStatus: vi.fn().mockReturnValue(MCPServerStatus.DISCONNECTED), getServerConfig: vi.fn(), + getServerName: vi.fn().mockReturnValue('test-server'), } as unknown as McpClient); vi.mocked(McpClient).mockReturnValue(mockedMcpClient); mockConfig = vi.mockObject({ isTrustedFolder: vi.fn().mockReturnValue(true), getMcpServers: vi.fn().mockReturnValue({}), - getPromptRegistry: () => {}, - getResourceRegistry: () => {}, + getPromptRegistry: vi.fn().mockReturnValue({ registerPrompt: vi.fn() }), + getResourceRegistry: vi + .fn() + .mockReturnValue({ setResourcesForServer: vi.fn() }), getDebugMode: () => false, - getWorkspaceContext: () => {}, + getWorkspaceContext: () => ({ getDirectories: () => [] }), getAllowedMcpServers: vi.fn().mockReturnValue([]), getBlockedMcpServers: vi.fn().mockReturnValue([]), + getExcludedMcpServers: vi.fn().mockReturnValue([]), getMcpServerCommand: vi.fn().mockReturnValue(''), getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined), getGeminiClient: vi.fn().mockReturnValue({ @@ -56,21 +62,39 @@ describe('McpClientManager', () => { }), refreshMcpContext: vi.fn(), } as unknown as Config); - toolRegistry = {} as ToolRegistry; + toolRegistry = vi.mockObject({ + registerTool: vi.fn(), + unregisterTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue({}), + removeMcpToolsByServer: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), + } as unknown as ToolRegistry); }); afterEach(() => { vi.restoreAllMocks(); }); + const setupManager = (manager: McpClientManager) => { + manager.setMainRegistries({ + toolRegistry, + promptRegistry: + mockConfig.getPromptRegistry() as unknown as PromptRegistry, + resourceRegistry: + mockConfig.getResourceRegistry() as unknown as ResourceRegistry, + }); + return manager; + }; + it('should discover tools from all configured', async () => { mockConfig.getMcpServers.mockReturnValue({ 'test-server': { command: 'node' }, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); }); @@ -80,12 +104,12 @@ describe('McpClientManager', () => { 'server-2': { command: 'node' }, 'server-3': { command: 'node' }, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); // Each client should be connected/discovered expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(3); // But context refresh should happen only once expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); @@ -95,7 +119,7 @@ describe('McpClientManager', () => { mockConfig.getMcpServers.mockReturnValue({ 'test-server': { command: 'node' }, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); @@ -112,7 +136,7 @@ describe('McpClientManager', () => { isFileEnabled: vi.fn().mockResolvedValue(false), }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); await promise; @@ -120,7 +144,7 @@ describe('McpClientManager', () => { expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED); expect(manager.getMcpServerCount()).toBe(0); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should mark discovery completed when all configured servers are blocked', async () => { @@ -129,7 +153,7 @@ describe('McpClientManager', () => { }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); await promise; @@ -137,7 +161,7 @@ describe('McpClientManager', () => { expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED); expect(manager.getMcpServerCount()).toBe(0); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should not discover tools if folder is not trusted', async () => { @@ -145,10 +169,10 @@ describe('McpClientManager', () => { 'test-server': { command: 'node' }, }); mockConfig.isTrustedFolder.mockReturnValue(false); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should not start blocked servers', async () => { @@ -156,10 +180,10 @@ describe('McpClientManager', () => { 'test-server': { command: 'node' }, }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should only start allowed servers if allow list is not empty', async () => { @@ -168,14 +192,14 @@ describe('McpClientManager', () => { 'another-server': { command: 'node' }, }); mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); }); it('should start servers from extensions', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startExtension({ name: 'test-extension', mcpServers: { @@ -188,11 +212,11 @@ describe('McpClientManager', () => { id: '123', }); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); }); it('should not start servers from disabled extensions', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startExtension({ name: 'test-extension', mcpServers: { @@ -205,7 +229,7 @@ describe('McpClientManager', () => { id: '123', }); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should add blocked servers to the blockedMcpServers list', async () => { @@ -213,7 +237,7 @@ describe('McpClientManager', () => { 'test-server': { command: 'node' }, }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(manager.getBlockedMcpServers()).toEqual([ { name: 'test-server', extensionName: '' }, @@ -224,10 +248,10 @@ describe('McpClientManager', () => { mockConfig.getMcpServers.mockReturnValue({ 'test-server': { excludeTools: ['dangerous_tool'] }, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); // But it should still be tracked in allServerConfigs expect(manager.getMcpServers()).toHaveProperty('test-server'); @@ -240,16 +264,16 @@ describe('McpClientManager', () => { 'test-server': serverConfig, }); mockedMcpClient.getServerConfig.mockReturnValue(serverConfig); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1); await manager.restart(); expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2); }); }); @@ -260,21 +284,21 @@ describe('McpClientManager', () => { 'test-server': serverConfig, }); mockedMcpClient.getServerConfig.mockReturnValue(serverConfig); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1); await manager.restartServer('test-server'); expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2); }); it('should throw an error if the server does not exist', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await expect(manager.restartServer('non-existent')).rejects.toThrow( 'No MCP server registered with the name "non-existent"', ); @@ -296,7 +320,7 @@ describe('McpClientManager', () => { }); mockedMcpClient.getServerConfig.mockReturnValue(originalConfig); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); // First call should use the original config @@ -321,9 +345,10 @@ describe('McpClientManager', () => { (name, config) => ({ connect: vi.fn(), - discover: vi.fn(), + discoverInto: vi.fn(), disconnect: vi.fn(), getServerConfig: vi.fn().mockReturnValue(config), + getServerName: vi.fn().mockReturnValue(name), getInstructions: vi .fn() .mockReturnValue( @@ -333,12 +358,7 @@ describe('McpClientManager', () => { ), }) as unknown as McpClient, ); - - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); mockConfig.getMcpServers.mockReturnValue({ 'server-with-instructions': { command: 'node' }, @@ -373,11 +393,7 @@ describe('McpClientManager', () => { 'test-server': { command: 'node' }, }); - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); await expect(manager.startConfiguredMcpServers()).resolves.not.toThrow(); }); @@ -396,11 +412,8 @@ describe('McpClientManager', () => { 'test-server': { command: 'node' }, }); - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); + await manager.startConfiguredMcpServers(); await expect(manager.restartServer('test-server')).resolves.not.toThrow(); @@ -409,7 +422,7 @@ describe('McpClientManager', () => { describe('Extension handling', () => { it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const mcpServers = { 'test-server': { command: 'node', args: ['server.js'] }, }; @@ -431,7 +444,7 @@ describe('McpClientManager', () => { }); it('should merge extension configuration with an existing user-configured server', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const userConfig = { command: 'node', args: ['user-server.js'] }; mockConfig.getMcpServers.mockReturnValue({ @@ -468,7 +481,7 @@ describe('McpClientManager', () => { }); it('should securely merge tool lists and env variables regardless of load order', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const userConfig = { excludeTools: ['user-tool'], @@ -523,7 +536,7 @@ describe('McpClientManager', () => { // Reset for Case 2 vi.mocked(McpClient).mockClear(); - const manager2 = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager2 = setupManager(new McpClientManager('0.0.1', mockConfig)); // Case 2: User config loads first, then Extension loads // This call will skip discovery because userConfig has no connection details @@ -551,7 +564,7 @@ describe('McpClientManager', () => { }); it('should result in empty includeTools if intersection is empty', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const userConfig = { includeTools: ['user-tool'] }; const extConfig = { command: 'node', @@ -567,7 +580,7 @@ describe('McpClientManager', () => { }); it('should respect a single allowlist if only one is provided', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const userConfig = { includeTools: ['user-tool'] }; const extConfig = { command: 'node', args: ['ext.js'] }; @@ -579,7 +592,7 @@ describe('McpClientManager', () => { }); it('should allow partial overrides of connection properties', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const extConfig = { command: 'node', args: ['ext.js'], timeout: 1000 }; const userOverride = { args: ['overridden.js'] }; @@ -599,7 +612,7 @@ describe('McpClientManager', () => { }); it('should prevent one extension from hijacking another extension server name', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const extension1: GeminiCLIExtension = { name: 'extension-1', @@ -641,7 +654,7 @@ describe('McpClientManager', () => { it('should remove servers from blockedMcpServers when stopExtension is called', async () => { mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const mcpServers = { 'blocked-server': { command: 'node', args: ['server.js'] }, }; @@ -679,7 +692,7 @@ describe('McpClientManager', () => { }); it('should emit hint instead of full error when user has not interacted with MCP', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic( 'error', 'Something went wrong', @@ -698,7 +711,7 @@ describe('McpClientManager', () => { }); it('should emit full error when user has interacted with MCP', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.setUserInteractedWithMcp(); manager.emitDiagnostic( 'error', @@ -714,7 +727,7 @@ describe('McpClientManager', () => { }); it('should still deduplicate diagnostic messages after user interaction', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.setUserInteractedWithMcp(); manager.emitDiagnostic('error', 'Same error'); @@ -724,7 +737,7 @@ describe('McpClientManager', () => { }); it('should only show hint once per session', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic('error', 'Error 1'); manager.emitDiagnostic('error', 'Error 2'); @@ -737,7 +750,7 @@ describe('McpClientManager', () => { }); it('should capture last error for a server even when silenced', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic( 'error', @@ -752,7 +765,7 @@ describe('McpClientManager', () => { }); it('should show previously deduplicated errors after interaction clears state', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic('error', 'Same error'); expect(coreEventsMock.emitFeedback).toHaveBeenCalledTimes(1); // The hint diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index b2a022402ec..a607b19508a 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -13,6 +13,7 @@ import type { ToolRegistry } from './tool-registry.js'; import { McpClient, MCPDiscoveryState, + MCPServerStatus, populateMcpServerCommand, } from './mcp-client.js'; import { getErrorMessage, isAuthenticationError } from '../utils/errors.js'; @@ -20,6 +21,11 @@ import type { EventEmitter } from 'node:events'; import { coreEvents } from '../utils/events.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { createHash } from 'node:crypto'; +import { stableStringify } from '../policy/stable-stringify.js'; +import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import type { ResourceRegistry } from '../resources/resource-registry.js'; + /** * Manages the lifecycle of multiple MCP clients, including local child processes. * This class is responsible for starting, stopping, and discovering tools from @@ -30,7 +36,6 @@ export class McpClientManager { // Track all configured servers (including disabled ones) for UI display private allServerConfigs: Map = new Map(); private readonly clientVersion: string; - private readonly toolRegistry: ToolRegistry; private readonly cliConfig: Config; // If we have ongoing MCP client discovery, this completes once that is done. private discoveryPromise: Promise | undefined; @@ -42,6 +47,10 @@ export class McpClientManager { extensionName: string; }> = []; + private mainToolRegistry: ToolRegistry | undefined; + private mainPromptRegistry: PromptRegistry | undefined; + private mainResourceRegistry: ResourceRegistry | undefined; + /** * Track whether the user has explicitly interacted with MCP in this session * (e.g. by running an /mcp command). @@ -66,16 +75,24 @@ export class McpClientManager { constructor( clientVersion: string, - toolRegistry: ToolRegistry, cliConfig: Config, eventEmitter?: EventEmitter, ) { this.clientVersion = clientVersion; - this.toolRegistry = toolRegistry; this.cliConfig = cliConfig; this.eventEmitter = eventEmitter; } + setMainRegistries(registries: { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; + }) { + this.mainToolRegistry = registries.toolRegistry; + this.mainPromptRegistry = registries.promptRegistry; + this.mainResourceRegistry = registries.resourceRegistry; + } + setUserInteractedWithMcp() { this.userInteractedWithMcp = true; } @@ -147,6 +164,16 @@ export class McpClientManager { return this.clients.get(serverName); } + removeRegistries(registries: { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; + }): void { + for (const client of this.clients.values()) { + client.removeRegistries(registries); + } + } + /** * For all the MCP servers associated with this extension: * @@ -236,16 +263,17 @@ export class McpClientManager { return false; } - private async disconnectClient(name: string, skipRefresh = false) { - const existing = this.clients.get(name); + private async disconnectClient(clientKey: string, skipRefresh = false) { + const existing = this.clients.get(clientKey); if (existing) { + const serverName = existing.getServerName(); try { - this.clients.delete(name); + this.clients.delete(clientKey); this.eventEmitter?.emit('mcp-client-update', this.clients); await existing.disconnect(); } catch (error) { debugLogger.warn( - `Error stopping client '${name}': ${getErrorMessage(error)}`, + `Error stopping client '${serverName}': ${getErrorMessage(error)}`, ); } finally { if (!skipRefresh) { @@ -257,6 +285,16 @@ export class McpClientManager { } } + private getClientKey(name: string, config: MCPServerConfig): string { + const { extension, ...rest } = config; + const keyData = { + name, + config: rest, + extensionId: extension?.id, + }; + return createHash('sha256').update(stableStringify(keyData)).digest('hex'); + } + /** * Merges two MCP configurations. The second configuration (override) * takes precedence for scalar properties, but array properties are @@ -305,6 +343,11 @@ export class McpClientManager { async maybeDiscoverMcpServer( name: string, config: MCPServerConfig, + registries?: { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; + }, ): Promise { const existingConfig = this.allServerConfigs.get(name); if ( @@ -337,11 +380,27 @@ export class McpClientManager { // Always track server config for UI display this.allServerConfigs.set(name, finalConfig); - // Capture the existing client synchronously here before any asynchronous - // operations. This ensures that if multiple discovery turns happen - // concurrently, this turn only replaces/disconnects the client that was - // present when this specific configuration update request began. - const existing = this.clients.get(name); + const clientKey = this.getClientKey(name, finalConfig); + + // If no registries are provided (main agent) and a server with this name already exists + // but with a different configuration, handle potential conflicts. + if (!registries) { + const existingSameName = Array.from(this.clients.values()).find( + (c) => c.getServerName() === name, + ); + if (existingSameName) { + const existingConfigFromClient = existingSameName.getServerConfig(); + const existingKey = this.getClientKey(name, existingConfigFromClient); + + if (existingKey !== clientKey) { + // This is a configuration update (hot-reload). + // We should stop the old client before starting the new one. + await this.disconnectClient(existingKey, true); + } + } + } + + const existing = this.clients.get(clientKey); // If no connection details are provided, we can't discover this server. // This often happens when a user provides only overrides (like excludeTools) @@ -363,7 +422,7 @@ export class McpClientManager { // User-disabled servers: disconnect if running, don't start if (await this.isDisabledByUser(name)) { if (existing) { - await this.disconnectClient(name); + await this.disconnectClient(clientKey); } return; } @@ -374,34 +433,48 @@ export class McpClientManager { return; } - const currentDiscoveryPromise = new Promise((resolve, reject) => { - (async () => { + const currentDiscoveryPromise = new Promise((resolve) => { + void (async () => { try { - if (existing) { - this.clients.delete(name); - await existing.disconnect(); + let client = existing; + if (!client) { + client = new McpClient( + name, + finalConfig, + this.cliConfig.getWorkspaceContext(), + this.cliConfig, + this.cliConfig.getDebugMode(), + this.clientVersion, + async () => { + debugLogger.log( + `🔔 Refreshing context for server '${name}'...`, + ); + await this.scheduleMcpContextRefresh(); + }, + ); + this.clients.set(clientKey, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); } - const client = new McpClient( - name, - finalConfig, - this.toolRegistry, - this.cliConfig.getPromptRegistry(), - this.cliConfig.getResourceRegistry(), - this.cliConfig.getWorkspaceContext(), - this.cliConfig, - this.cliConfig.getDebugMode(), - this.clientVersion, - async () => { - debugLogger.log(`🔔 Refreshing context for server '${name}'...`); - await this.scheduleMcpContextRefresh(); - }, - ); - this.clients.set(name, client); - this.eventEmitter?.emit('mcp-client-update', this.clients); + const targetRegistries = + registries ?? + (this.mainToolRegistry && + this.mainPromptRegistry && + this.mainResourceRegistry + ? { + toolRegistry: this.mainToolRegistry, + promptRegistry: this.mainPromptRegistry, + resourceRegistry: this.mainResourceRegistry, + } + : undefined); + try { - await client.connect(); - await client.discover(this.cliConfig); + if (client.getStatus() === MCPServerStatus.DISCONNECTED) { + await client.connect(); + } + if (targetRegistries) { + await client.discoverInto(this.cliConfig, targetRegistries); + } this.eventEmitter?.emit('mcp-client-update', this.clients); } catch (error) { this.eventEmitter?.emit('mcp-client-update', this.clients); @@ -421,13 +494,13 @@ export class McpClientManager { const errorMessage = getErrorMessage(error); this.emitDiagnostic( 'error', - `Error initializing MCP server '${name}': ${errorMessage}`, + `Fatal error ensuring MCP server '${name}' is connected: ${errorMessage}`, error, ); } finally { resolve(); } - })().catch(reject); + })(); }); if (this.discoveryPromise) { @@ -510,6 +583,11 @@ export class McpClientManager { * Restarts all MCP servers (including newly enabled ones). */ async restart(): Promise { + const disconnectionPromises = Array.from(this.clients.keys()).map((key) => + this.disconnectClient(key, true), + ); + await Promise.all(disconnectionPromises); + await Promise.all( Array.from(this.allServerConfigs.entries()).map( async ([name, config]) => { @@ -534,6 +612,8 @@ export class McpClientManager { if (!config) { throw new Error(`No MCP server registered with the name "${name}"`); } + const clientKey = this.getClientKey(name, config); + await this.disconnectClient(clientKey, true); await this.maybeDiscoverMcpServer(name, config); await this.scheduleMcpContextRefresh(); } @@ -578,11 +658,12 @@ export class McpClientManager { getMcpInstructions(): string { const instructions: string[] = []; - for (const [name, client] of this.clients) { + for (const client of this.clients.values()) { + const serverName = client.getServerName(); const clientInstructions = client.getInstructions(); if (clientInstructions) { instructions.push( - `The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`, + `The following are instructions provided by the tool server '${serverName}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`, ); } } diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 21b5c28615f..4a14b671a0d 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +/* eslint-disable @typescript-eslint/no-explicit-any */ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -160,16 +161,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedClient.listTools).toHaveBeenCalledWith( {}, expect.objectContaining({ timeout: 600000, progressReporter: client }), @@ -244,16 +246,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); expect(consoleWarnSpy).not.toHaveBeenCalled(); consoleWarnSpy.mockRestore(); @@ -296,16 +299,19 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow('Test error'); + await expect( + client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }), + ).rejects.toThrow('Test error'); expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith( 'error', `Error discovering prompts from test-server: Test error`, @@ -354,18 +360,19 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow( - 'No prompts, tools, or resources found on the server.', - ); + await expect( + client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }), + ).rejects.toThrow('No prompts, tools, or resources found on the server.'); }); it('should discover tools if server supports them', async () => { @@ -417,16 +424,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); }); @@ -485,9 +493,6 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -495,7 +500,11 @@ describe('mcp-client', () => { ); await client.connect(); - await client.discover(mockConfig); + await client.discoverInto(mockConfig, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); // Verify tool registration expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); @@ -566,9 +575,6 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -576,7 +582,11 @@ describe('mcp-client', () => { ); await client.connect(); - await client.discover(mockConfig); + await client.discoverInto(mockConfig, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockPolicyEngine.addRule).not.toHaveBeenCalled(); @@ -644,9 +654,6 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -654,7 +661,11 @@ describe('mcp-client', () => { ); await client.connect(); - await client.discover(mockConfig); + await client.discoverInto(mockConfig, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); @@ -733,16 +744,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock .calls[0][0]; @@ -818,16 +830,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith( 'test-server', [ @@ -907,16 +920,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2); expect(resourceListHandler).toBeDefined(); @@ -996,16 +1010,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - promptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2); expect(promptListHandler).toBeDefined(); @@ -1080,16 +1095,17 @@ describe('mcp-client', () => { { command: 'test-command', }, - mockedToolRegistry, - mockedPromptRegistry, - resourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', ); await client.connect(); - await client.discover(MOCK_CONTEXT); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry: mockedPromptRegistry, + resourceRegistry, + }); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce(); @@ -1138,24 +1154,27 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - { + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: { getPromptsByServer: vi.fn().mockReturnValue([]), registerPrompt: vi.fn(), } as unknown as PromptRegistry, - { + resourceRegistry: { getResourcesByServer: vi.fn().mockReturnValue([]), registerResource: vi.fn(), removeResourcesByServer: vi.fn(), setResourcesForServer: vi.fn(), } as unknown as ResourceRegistry, - workspaceContext, - MOCK_CONTEXT, - false, - '0.0.1', - ); - - await client.connect(); + }); expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith( ToolListChangedNotificationSchema, @@ -1183,28 +1202,31 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - { + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: { getToolsByServer: vi.fn().mockReturnValue([]), registerTool: vi.fn(), sortTools: vi.fn(), } as unknown as ToolRegistry, - { + promptRegistry: { getPromptsByServer: vi.fn().mockReturnValue([]), registerPrompt: vi.fn(), } as unknown as PromptRegistry, - { + resourceRegistry: { getResourcesByServer: vi.fn().mockReturnValue([]), registerResource: vi.fn(), removeResourcesByServer: vi.fn(), setResourcesForServer: vi.fn(), } as unknown as ResourceRegistry, - workspaceContext, - MOCK_CONTEXT, - false, - '0.0.1', - ); - - await client.connect(); + }); // Should be called for ProgressNotificationSchema, even if no other capabilities expect(mockedClient.setNotificationHandler).toHaveBeenCalled(); @@ -1234,28 +1256,31 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - { + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: { getToolsByServer: vi.fn().mockReturnValue([]), registerTool: vi.fn(), sortTools: vi.fn(), } as unknown as ToolRegistry, - { + promptRegistry: { getPromptsByServer: vi.fn().mockReturnValue([]), registerPrompt: vi.fn(), } as unknown as PromptRegistry, - { + resourceRegistry: { getResourcesByServer: vi.fn().mockReturnValue([]), registerResource: vi.fn(), removeResourcesByServer: vi.fn(), setResourcesForServer: vi.fn(), } as unknown as ResourceRegistry, - workspaceContext, - MOCK_CONTEXT, - false, - '0.0.1', - ); - - await client.connect(); + }); const toolUpdateCall = mockedClient.setNotificationHandler.mock.calls.find( @@ -1308,12 +1333,6 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - {} as PromptRegistry, - { - removeMcpResourcesByServer: vi.fn(), - registerResource: vi.fn(), - } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1323,6 +1342,15 @@ describe('mcp-client', () => { // 1. Connect (sets up listener) await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: {} as PromptRegistry, + resourceRegistry: { + removeMcpResourcesByServer: vi.fn(), + registerResource: vi.fn(), + } as unknown as ResourceRegistry, + }); // 2. Extract the callback passed to setNotificationHandler for tools const toolUpdateCall = @@ -1388,9 +1416,6 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1398,6 +1423,12 @@ describe('mcp-client', () => { ); await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: {} as PromptRegistry, + resourceRegistry: {} as ResourceRegistry, + }); const toolUpdateCall = mockedClient.setNotificationHandler.mock.calls.find( @@ -1463,9 +1494,6 @@ describe('mcp-client', () => { const clientA = new McpClient( 'server-A', { command: 'cmd-a' }, - mockedToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1476,9 +1504,6 @@ describe('mcp-client', () => { const clientB = new McpClient( 'server-B', { command: 'cmd-b' }, - mockedToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1487,7 +1512,19 @@ describe('mcp-client', () => { ); await clientA.connect(); + // INJECTED REGISTRIES + (clientA as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: {} as PromptRegistry, + resourceRegistry: {} as ResourceRegistry, + }); await clientB.connect(); + // INJECTED REGISTRIES + (clientB as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: {} as PromptRegistry, + resourceRegistry: {} as ResourceRegistry, + }); const toolUpdateCallA = mockClientA.setNotificationHandler.mock.calls.find( @@ -1572,25 +1609,28 @@ describe('mcp-client', () => { 'test-server', // Set a very short timeout { command: 'test-command', timeout: 50 }, - mockedToolRegistry, - { + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: { getPromptsByServer: vi.fn().mockReturnValue([]), registerPrompt: vi.fn(), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry, - { + resourceRegistry: { getResourcesByServer: vi.fn().mockReturnValue([]), registerResource: vi.fn(), removeResourcesByServer: vi.fn(), setResourcesForServer: vi.fn(), } as unknown as ResourceRegistry, - workspaceContext, - MOCK_CONTEXT, - false, - '0.0.1', - ); - - await client.connect(); + }); const toolUpdateCall = mockedClient.setNotificationHandler.mock.calls.find( @@ -1648,26 +1688,29 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - mockedToolRegistry, - { + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + onContextUpdatedSpy, + ); + + await client.connect(); + // INJECTED REGISTRIES + (client as any).registeredRegistries?.add({ + toolRegistry: mockedToolRegistry, + promptRegistry: { getPromptsByServer: vi.fn().mockReturnValue([]), registerPrompt: vi.fn(), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry, - { + resourceRegistry: { getResourcesByServer: vi.fn().mockReturnValue([]), registerResource: vi.fn(), removeResourcesByServer: vi.fn(), setResourcesForServer: vi.fn(), } as unknown as ResourceRegistry, - workspaceContext, - MOCK_CONTEXT, - false, - '0.0.1', - onContextUpdatedSpy, - ); - - await client.connect(); + }); const toolUpdateCall = mockedClient.setNotificationHandler.mock.calls.find( diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index b3e1023b594..58b7b6c8e22 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -130,6 +130,12 @@ export interface McpProgressReporter { unregisterProgressToken(token: string | number): void; } +export interface RegistrySet { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; +} + /** * A client for a single MCP server. * @@ -147,6 +153,8 @@ export class McpClient implements McpProgressReporter { private isRefreshingPrompts: boolean = false; private pendingPromptRefresh: boolean = false; + private readonly registeredRegistries = new Set(); + /** * Map of progress tokens to tool call IDs. * This allows us to route progress notifications to the correct tool call. @@ -156,9 +164,6 @@ export class McpClient implements McpProgressReporter { constructor( private readonly serverName: string, private readonly serverConfig: MCPServerConfig, - private readonly toolRegistry: ToolRegistry, - private readonly promptRegistry: PromptRegistry, - private readonly resourceRegistry: ResourceRegistry, private readonly workspaceContext: WorkspaceContext, private readonly cliConfig: McpContext, private readonly debugMode: boolean, @@ -166,6 +171,10 @@ export class McpClient implements McpProgressReporter { private readonly onContextUpdated?: (signal?: AbortSignal) => Promise, ) {} + getServerName(): string { + return this.serverName; + } + /** * Connects to the MCP server. */ @@ -210,27 +219,34 @@ export class McpClient implements McpProgressReporter { } /** - * Discovers tools and prompts from the MCP server. + * Discovers tools and prompts from the MCP server into the specified registries. */ - async discover(cliConfig: McpContext): Promise { + async discoverInto( + cliConfig: McpContext, + registries: RegistrySet, + ): Promise { this.assertConnected(); + this.registeredRegistries.add(registries); const prompts = await this.fetchPrompts(); - const tools = await this.discoverTools(cliConfig); + const tools = await this.discoverTools( + cliConfig, + registries.toolRegistry.getMessageBus(), + ); const resources = await this.discoverResources(); - this.updateResourceRegistry(resources); + this.updateResourceRegistry(resources, registries.resourceRegistry); if (prompts.length === 0 && tools.length === 0 && resources.length === 0) { throw new Error('No prompts, tools, or resources found on the server.'); } for (const prompt of prompts) { - this.promptRegistry.registerPrompt(prompt); + registries.promptRegistry.registerPrompt(prompt); } for (const tool of tools) { - this.toolRegistry.registerTool(tool); + registries.toolRegistry.registerTool(tool); } - this.toolRegistry.sortTools(); + registries.toolRegistry.sortTools(); // Validate MCP tool names in policy rules against discovered tools try { @@ -250,6 +266,14 @@ export class McpClient implements McpProgressReporter { } } + /** + * Unregisters registries so this client will no longer update them when it receives + * list_changed notifications from the server. + */ + removeRegistries(registries: RegistrySet): void { + this.registeredRegistries.delete(registries); + } + /** * Disconnects from the MCP server. */ @@ -257,9 +281,11 @@ export class McpClient implements McpProgressReporter { if (this.status !== MCPServerStatus.CONNECTED) { return; } - this.toolRegistry.removeMcpToolsByServer(this.serverName); - this.promptRegistry.removePromptsByServer(this.serverName); - this.resourceRegistry.removeResourcesByServer(this.serverName); + for (const registries of this.registeredRegistries) { + registries.toolRegistry.removeMcpToolsByServer(this.serverName); + registries.promptRegistry.removePromptsByServer(this.serverName); + registries.resourceRegistry.removeResourcesByServer(this.serverName); + } this.updateStatus(MCPServerStatus.DISCONNECTING); const client = this.client; this.client = undefined; @@ -294,6 +320,7 @@ export class McpClient implements McpProgressReporter { private async discoverTools( cliConfig: McpContext, + messageBus: MessageBus, options?: { timeout?: number; signal?: AbortSignal }, ): Promise { this.assertConnected(); @@ -302,7 +329,7 @@ export class McpClient implements McpProgressReporter { this.serverConfig, this.client!, cliConfig, - this.toolRegistry.messageBus, + messageBus, { ...(options ?? { timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, @@ -329,8 +356,11 @@ export class McpClient implements McpProgressReporter { return discoverResources(this.serverName, this.client!, this.cliConfig); } - private updateResourceRegistry(resources: Resource[]): void { - this.resourceRegistry.setResourcesForServer(this.serverName, resources); + private updateResourceRegistry( + resources: Resource[], + resourceRegistry: ResourceRegistry, + ): void { + resourceRegistry.setResourcesForServer(this.serverName, resources); } async readResource( @@ -482,23 +512,32 @@ export class McpClient implements McpProgressReporter { try { newResources = await this.discoverResources(); - // Verification Retry: If no resources are found or resources didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentResources = - this.resourceRegistry.getResourcesByServer(this.serverName) || []; - const resourceMatch = - newResources.length === currentResources.length && - newResources.every((nr: Resource) => - currentResources.some((cr: MCPResource) => cr.uri === nr.uri), - ); + for (const registries of this.registeredRegistries) { + // Verification Retry: If no resources are found or resources didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentResources = + registries.resourceRegistry.getResourcesByServer( + this.serverName, + ) || []; + const resourceMatch = + newResources.length === currentResources.length && + newResources.every((nr: Resource) => + currentResources.some((cr: MCPResource) => cr.uri === nr.uri), + ); - if (resourceMatch && !this.pendingResourceRefresh) { - debugLogger.log( - `No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`, + if (resourceMatch && !this.pendingResourceRefresh) { + debugLogger.log( + `No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newResources = await this.discoverResources(); + } + + this.updateResourceRegistry( + newResources, + registries.resourceRegistry, ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newResources = await this.discoverResources(); } } catch (err) { debugLogger.error( @@ -508,8 +547,6 @@ export class McpClient implements McpProgressReporter { break; } - this.updateResourceRegistry(newResources); - if (this.onContextUpdated) { await this.onContextUpdated(abortController.signal); } @@ -575,30 +612,33 @@ export class McpClient implements McpProgressReporter { signal: abortController.signal, }); - // Verification Retry: If no prompts are found or prompts didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentPrompts = - this.promptRegistry.getPromptsByServer(this.serverName) || []; - const promptsMatch = - newPrompts.length === currentPrompts.length && - newPrompts.every((np) => - currentPrompts.some((cp) => cp.name === np.name), - ); + for (const registries of this.registeredRegistries) { + // Verification Retry: If no prompts are found or prompts didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentPrompts = + registries.promptRegistry.getPromptsByServer(this.serverName) || + []; + const promptsMatch = + newPrompts.length === currentPrompts.length && + newPrompts.every((np) => + currentPrompts.some((cp) => cp.name === np.name), + ); - if (promptsMatch && !this.pendingPromptRefresh) { - debugLogger.log( - `No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`, - ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newPrompts = await this.fetchPrompts({ - signal: abortController.signal, - }); - } + if (promptsMatch && !this.pendingPromptRefresh) { + debugLogger.log( + `No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newPrompts = await this.fetchPrompts({ + signal: abortController.signal, + }); + } - this.promptRegistry.removePromptsByServer(this.serverName); - for (const prompt of newPrompts) { - this.promptRegistry.registerPrompt(prompt); + registries.promptRegistry.removePromptsByServer(this.serverName); + for (const prompt of newPrompts) { + registries.promptRegistry.registerPrompt(prompt); + } } } catch (err) { debugLogger.error( @@ -666,42 +706,58 @@ export class McpClient implements McpProgressReporter { const abortController = new AbortController(); const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); - let newTools; try { - newTools = await this.discoverTools(this.cliConfig, { - signal: abortController.signal, - }); - debugLogger.log( - `Refresh for '${this.serverName}' discovered ${newTools.length} tools.`, - ); - - // Verification Retry (Option 3): If no tools are found or tools didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentTools = - this.toolRegistry.getToolsByServer(this.serverName) || []; - const toolNamesMatch = - newTools.length === currentTools.length && - newTools.every((nt) => - currentTools.some( - (ct) => - ct.name === nt.name || - (ct instanceof DiscoveredMCPTool && - ct.serverToolName === nt.serverToolName), - ), - ); - - if (toolNamesMatch && !this.pendingToolRefresh) { - debugLogger.log( - `No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`, + for (const registries of this.registeredRegistries) { + let newTools = await this.discoverTools( + this.cliConfig, + registries.toolRegistry.getMessageBus(), + { + signal: abortController.signal, + }, ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newTools = await this.discoverTools(this.cliConfig, { - signal: abortController.signal, - }); debugLogger.log( - `Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`, + `Refresh for '${this.serverName}' discovered ${newTools.length} tools.`, ); + + // Verification Retry (Option 3): If no tools are found or tools didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentTools = + registries.toolRegistry.getToolsByServer(this.serverName) || []; + const toolNamesMatch = + newTools.length === currentTools.length && + newTools.every((nt) => + currentTools.some( + (ct) => + ct.name === nt.name || + (ct instanceof DiscoveredMCPTool && + ct.serverToolName === nt.serverToolName), + ), + ); + + if (toolNamesMatch && !this.pendingToolRefresh) { + debugLogger.log( + `No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newTools = await this.discoverTools( + this.cliConfig, + registries.toolRegistry.getMessageBus(), + { + signal: abortController.signal, + }, + ); + debugLogger.log( + `Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`, + ); + } + + registries.toolRegistry.removeMcpToolsByServer(this.serverName); + + for (const tool of newTools) { + registries.toolRegistry.registerTool(tool); + } + registries.toolRegistry.sortTools(); } } catch (err) { debugLogger.error( @@ -711,13 +767,6 @@ export class McpClient implements McpProgressReporter { break; } - this.toolRegistry.removeMcpToolsByServer(this.serverName); - - for (const tool of newTools) { - this.toolRegistry.registerTool(tool); - } - this.toolRegistry.sortTools(); - if (this.onContextUpdated) { await this.onContextUpdated(abortController.signal); } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index ba272006337..291f43d9080 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -284,6 +284,26 @@ describe('ToolRegistry', () => { }); }); + describe('removeMcpToolsByServer', () => { + it('should remove all tools from a specific server', () => { + const serverName = 'test-server'; + const mcpTool1 = createMCPTool(serverName, 'tool1', 'desc1'); + const mcpTool2 = createMCPTool(serverName, 'tool2', 'desc2'); + const otherTool = createMCPTool('other-server', 'tool3', 'desc3'); + + toolRegistry.registerTool(mcpTool1); + toolRegistry.registerTool(mcpTool2); + toolRegistry.registerTool(otherTool); + + expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(2); + + toolRegistry.removeMcpToolsByServer(serverName); + + expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(0); + expect(toolRegistry.getToolsByServer('other-server')).toHaveLength(1); + }); + }); + describe('excluded tools', () => { const simpleTool = new MockTool({ name: 'tool-a', diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 7e1faffb426..c91e4ca7e3b 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -223,10 +223,16 @@ export class ToolRegistry { private allKnownTools: Map = new Map(); private config: Config; readonly messageBus: MessageBus; + private isMainRegistry: boolean; - constructor(config: Config, messageBus: MessageBus) { + constructor( + config: Config, + messageBus: MessageBus, + isMainRegistry: boolean = false, + ) { this.config = config; this.messageBus = messageBus; + this.isMainRegistry = isMainRegistry; } getMessageBus(): MessageBus { @@ -599,6 +605,10 @@ export class ToolRegistry { const declarations: FunctionDeclaration[] = []; const seenNames = new Set(); + const mainAgentTools = this.isMainRegistry + ? this.config.getMainAgentTools() + : undefined; + this.getActiveTools().forEach((tool) => { const toolName = tool instanceof DiscoveredMCPTool @@ -608,6 +618,16 @@ export class ToolRegistry { if (seenNames.has(toolName)) { return; } + + if ( + mainAgentTools && + !mainAgentTools.includes(toolName) && + !mainAgentTools.includes(tool.constructor.name) && + !mainAgentTools.some((t) => t.startsWith(`${tool.constructor.name}(`)) + ) { + return; + } + seenNames.add(toolName); let schema = tool.getSchema(modelId);