Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/a2a-server/src/config/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const mockConfig = {
...params,
initialize: vi.fn(),
waitForMcpInit: vi.fn(),
refreshAuth: vi.fn(),
getExperiments: vi.fn().mockReturnValue({
flags: {
Expand Down Expand Up @@ -94,6 +95,7 @@ describe('loadConfig', () => {
const mockConfig = {
...(params as object),
initialize: vi.fn(),
waitForMcpInit: vi.fn(),
refreshAuth: vi.fn(),
getExperiments: vi.fn().mockReturnValue({
flags: {
Expand Down
2 changes: 2 additions & 0 deletions packages/a2a-server/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ export async function loadConfig(

// Needed to initialize ToolRegistry, and git checkpointing if enabled
await config.initialize();

await config.waitForMcpInit();
startupProfiler.flush(config);

await refreshAuthentication(config, adcFilePath, 'Config');
Expand Down
24 changes: 24 additions & 0 deletions packages/cli/src/zed-integration/zedIntegration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ describe('GeminiAgent', () => {
mockConfig = {
refreshAuth: vi.fn(),
initialize: vi.fn(),
waitForMcpInit: vi.fn(),
getFileSystemService: vi.fn(),
setFileSystemService: vi.fn(),
getContentGeneratorConfig: vi.fn(),
Expand Down Expand Up @@ -486,6 +487,7 @@ describe('Session', () => {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
setApprovalMode: vi.fn(),
isPlanEnabled: vi.fn().mockReturnValue(false),
waitForMcpInit: vi.fn(),
} as unknown as Mocked<Config>;
mockConnection = {
sessionUpdate: vi.fn(),
Expand All @@ -500,6 +502,28 @@ describe('Session', () => {
vi.clearAllMocks();
});

it('should await MCP initialization before processing a prompt', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] },
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);

await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'test' }],
});

expect(mockConfig.waitForMcpInit).toHaveBeenCalledOnce();
const waitOrder = (mockConfig.waitForMcpInit as Mock).mock
.invocationCallOrder[0];
const sendOrder = (mockChat.sendMessageStream as Mock).mock
.invocationCallOrder[0];
expect(waitOrder).toBeLessThan(sendOrder);
});

it('should handle prompt with text response', async () => {
const stream = createMockStream([
{
Expand Down
2 changes: 2 additions & 0 deletions packages/cli/src/zed-integration/zedIntegration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ export class Session {
const pendingSend = new AbortController();
this.pendingPrompt = pendingSend;

await this.config.waitForMcpInit();

const promptId = Math.random().toString(16).slice(2);
const chat = this.chat;

Expand Down
11 changes: 9 additions & 2 deletions packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ export class Config implements McpContext {
private compressionTruncationCounter = 0;
private initialized = false;
private initPromise: Promise<void> | undefined;
private mcpInitializationPromise: Promise<void> | null = null;
readonly storage: Storage;
private readonly fileExclusions: FileExclusions;
private readonly eventEmitter?: EventEmitter;
Expand Down Expand Up @@ -1124,7 +1125,7 @@ export class Config implements McpContext {
);
// We do not await this promise so that the CLI can start up even if
// MCP servers are slow to connect.
const mcpInitialization = Promise.allSettled([
this.mcpInitializationPromise = Promise.allSettled([
this.mcpClientManager.startConfiguredMcpServers(),
this.getExtensionLoader().start(this),
]).then((results) => {
Expand All @@ -1136,7 +1137,7 @@ export class Config implements McpContext {
});

if (!this.interactive || this.experimentalZedIntegration) {
await mcpInitialization;
await this.mcpInitializationPromise;
}

if (this.skillsSupport) {
Expand Down Expand Up @@ -2234,6 +2235,12 @@ export class Config implements McpContext {
return this.experimentalZedIntegration;
}

async waitForMcpInit(): Promise<void> {
if (this.mcpInitializationPromise) {
await this.mcpInitializationPromise;
}
}

getListExtensions(): boolean {
return this.listExtensions;
}
Expand Down
Loading