Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions plugin/controller/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ export default class ControllerAppBootHook {
// The HTTPControllerRegister will collect all the methods
// and register methods after collect is done.
HTTPControllerRegister.instance?.doRegister(this.app.rootProtoManager);

this.app.config.mcp.hooks = MCPControllerRegister.hooks;
}

async willReady() {
Expand Down
19 changes: 18 additions & 1 deletion plugin/controller/app/middleware/mcp_body_middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,24 @@ export default () => {
});
if (res) {
ctx.disableBodyParser = true;
try {
for (const hook of ctx.app.config.mcp.hooks) {
await hook.middlewareStart?.(ctx);
}
await next();
if (!ctx.mcpArg) {
ctx.mcpArg = JSON.parse(ctx.response.header['mcp-proxy-arg'] as string ?? '{}');
}
for (const hook of ctx.app.config.mcp.hooks) {
await hook.middlewareEnd?.(ctx);
}
} catch (e) {
for (const hook of ctx.app.config.mcp.hooks) {
await hook.middlewareError?.(ctx, e);
}
}
} else {
await next();
}
await next();
};
};
49 changes: 41 additions & 8 deletions plugin/controller/lib/impl/mcp/MCPControllerRegister.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { ControllerRegister } from '../../ControllerRegister';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { isInitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import { isInitializeRequest, isJSONRPCRequest, JSONRPCMessage, MessageExtraInfo } from '@modelcontextprotocol/sdk/types.js';
import { MCPProtocols } from '@eggjs/mcp-proxy/types';
import awaitEvent from 'await-event';
import compose from 'koa-compose';
Expand All @@ -43,6 +43,11 @@ export interface MCPControllerHook {
preProxy?: (ctx: Context, proxyReq: http.IncomingMessage, proxyResp: http.ServerResponse) => Promise<void>
schemaLoader?: (controllerMeta: MCPControllerMeta, meta: MCPPromptMeta | MCPToolMeta) => Promise<Parameters<McpServer['tool']>['2'] | Parameters<McpServer['prompt']>['2']>
checkAndRunProxy?: (ctx: Context, type: MCPProtocols, sessionId: string) => Promise<boolean>;

// middleware
middlewareStart?: (ctx: Context) => Promise<void>
middlewareEnd?: (ctx: Context) => Promise<void>
middlewareError?: (ctx: Context, e: Error) => Promise<void>
}

class InnerSSEServerTransport extends SSEServerTransport {
Expand Down Expand Up @@ -183,6 +188,14 @@ export class MCPControllerRegister implements ControllerRegister {
if (self.globalMiddlewares) {
mw = compose([ mw, self.globalMiddlewares ]);
}
const onmessage = transport.onmessage;

transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
if (self.app.currentContext) {
self.app.currentContext.mcpArg = message;
}
onmessage && await onmessage(message, extra);
};
const initHandler = async (ctx: Context) => {
if (MCPControllerRegister.hooks.length > 0) {
for (const hook of MCPControllerRegister.hooks) {
Expand All @@ -194,7 +207,6 @@ export class MCPControllerRegister implements ControllerRegister {
'content-type': 'text/event-stream',
'transfer-encoding': 'chunked',
});

await ctx.app.ctxStorage.run(ctx, async () => {
await mw(ctx, async () => {
await transport.handleRequest(ctx.req, ctx.res);
Expand Down Expand Up @@ -253,7 +265,7 @@ export class MCPControllerRegister implements ControllerRegister {
}
const sessionId = ctx.req.headers['mcp-session-id'] as string | undefined;
if (!sessionId) {
const ct = contentType.parse(ctx.req.headers['content-type'] ?? '');
const ct = contentType.parse(ctx.req.headers['content-type'] || 'application/json');

let body;

Expand Down Expand Up @@ -307,6 +319,15 @@ export class MCPControllerRegister implements ControllerRegister {
transport,
);

const onmessage = transport.onmessage;

transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
if (self.app.currentContext) {
self.app.currentContext.mcpArg = message;
}
onmessage && await onmessage(message, extra);
};

await ctx.app.ctxStorage.run(ctx, async () => {
await mw(ctx, async () => {
await transport.handleRequest(ctx.req, ctx.res, body);
Expand Down Expand Up @@ -338,6 +359,7 @@ export class MCPControllerRegister implements ControllerRegister {
'content-type': 'text/event-stream',
'transfer-encoding': 'chunked',
});

await ctx.app.ctxStorage.run(ctx, async () => {
await mw(ctx, async () => {
await transport.handleRequest(ctx.req, ctx.res);
Expand Down Expand Up @@ -442,7 +464,8 @@ export class MCPControllerRegister implements ControllerRegister {
};
const messageFunc = transport.onmessage;
self.sseTransportsRequestMap.set(transport, {});
transport.onmessage = async (...args: [JSONRPCMessage]) => {
transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
const args = [ message, extra ];
// 这里需要 new 一个新的 ctx,否则 ContextProto 会未被初始化
const socket = new Socket();
const req = new IncomingMessage(socket);
Expand All @@ -451,6 +474,7 @@ export class MCPControllerRegister implements ControllerRegister {
req.url = self.mcpConfig.getSseInitPath(name);
req.headers = {
...ctx.req.headers,
...extra?.requestInfo?.headers,
accept: 'application/json, text/event-stream',
'content-type': 'application/json',
};
Expand All @@ -462,12 +486,12 @@ export class MCPControllerRegister implements ControllerRegister {
}
}
await mw(newCtx, async () => {
messageFunc!(...args);
messageFunc!(message, extra);
if (isJSONRPCRequest(args[0])) {
const map = self.sseTransportsRequestMap.get(transport)!;
const wait = new Promise<null>((resolve, reject) => {
if ('id' in args[0]) {
map[args[0].id] = { resolve, reject };
if (extra && 'id' in extra) {
map[extra.id as string] = { resolve, reject };
}
});
await wait;
Expand Down Expand Up @@ -500,7 +524,16 @@ export class MCPControllerRegister implements ControllerRegister {
}
self.app.logger.info('message coming', sessionId);
try {
await self.transports[sessionId].handlePostMessage(ctx.req, ctx.res);
const ct = contentType.parse(ctx.req.headers['content-type'] ?? '');

const rawBody = await getRawBody(ctx.req, {
limit: '4mb',
encoding: ct.parameters.charset ?? 'utf-8',
});

const body = JSON.parse(rawBody);
ctx.mcpArg = body;
await self.transports[sessionId].handlePostMessage(ctx.req, ctx.res, body);
} catch (error) {
self.app.logger.error('Error handling MCP message', error);
if (!ctx.res.headersSent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Extra,
ToolExtra,
} from '@eggjs/tegg';
import z from 'zod';

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-macos (16)

Using exported name 'z' as identifier for default import

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-macos (18)

Using exported name 'z' as identifier for default import

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-ubuntu (16)

Using exported name 'z' as identifier for default import

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-macos (20)

Using exported name 'z' as identifier for default import

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-ubuntu (20)

Using exported name 'z' as identifier for default import

Check warning on line 19 in plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

View workflow job for this annotation

GitHub Actions / Runner-ubuntu (18)

Using exported name 'z' as identifier for default import

export const PromptType = {
name: z.string(),
Expand Down Expand Up @@ -80,6 +80,12 @@
};
}

@MCPTool()
async mockError(): Promise<MCPToolResponse> {
throw new Error('mock error');
}


@MCPTool()
async echoUser(): Promise<MCPToolResponse> {
return {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
'use strict';
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { randomUUID } = require('node:crypto');
// eslint-disable-next-line @typescript-eslint/no-var-requires
const path = require('node:path');

module.exports = function() {
module.exports = function(appInfo) {
const config = {
keys: 'test key',
mcp: {
Expand All @@ -20,6 +22,17 @@ module.exports = function() {
},
},
},
customLogger: {
mcpMiddewareStartLogger: {
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareStart.log'),
},
mcpMiddewareEndLogger: {
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareEnd.log'),
},
mcpMiddewareErrorLogger: {
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareError.log'),
},
},
};
return config;
};
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ export const GetAlipayTeggHook = (app: Application) => {
async preProxy(ctx) {
setUser(ctx);
},
async middlewareStart(ctx) {
ctx.mcpStartTime = Date.now();
ctx.getLogger('mcpMiddewareStartLogger').info('mcp middleware start');
},
async middlewareEnd(ctx) {
ctx.getLogger('mcpMiddewareEndLogger').info('mcp middleware end, arg: ', JSON.stringify(ctx.mcpArg), `, time: ${Date.now() - ctx.mcpStartTime}`);
},
async middlewareError(ctx, e) {
ctx.getLogger('mcpMiddewareErrorLogger').info('mcp middleware error: ', e);
},
};

return AlipayTeggControllerHook;
Expand Down
53 changes: 53 additions & 0 deletions plugin/controller/test/mcp/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
description: undefined,
name: 'bar',
},
{
description: undefined,
name: 'mockError',
},
{
description: undefined,
name: 'echoUser',
Expand Down Expand Up @@ -219,6 +223,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
],
});
await sseTransport.close();


const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes('mcp middleware start'));
assert.ok(middlewareEndTracelog.includes('mcp middleware end, arg: {'));
});

it('streamable should work', async () => {
Expand Down Expand Up @@ -269,6 +280,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
description: undefined,
name: 'bar',
},
{
description: undefined,
name: 'mockError',
},
{
description: undefined,
name: 'echoUser',
Expand Down Expand Up @@ -364,6 +379,12 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
const logContent = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/mcp-app/mcp-app-web.log'));

assert.ok(logContent.includes('startNotificationStream finish'));

const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes(' POST /mcp/stream] mcp middleware start'));
assert.ok(middlewareEndTracelog.includes(' POST /mcp/stream] mcp middleware end, arg: '));
});

it('stateless streamable should work', async () => {
Expand Down Expand Up @@ -412,6 +433,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
description: undefined,
name: 'bar',
},
{
description: undefined,
name: 'mockError',
},
{
description: undefined,
name: 'echoUser',
Expand Down Expand Up @@ -503,6 +528,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {

await streamableTransport.terminateSession();
await streamableClient.close();


const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes(' POST /mcp/stateless/stream] mcp middleware start'));
assert.ok(middlewareEndTracelog.includes(' POST /mcp/stateless/stream] mcp middleware end, arg: {'));
});

it('multiple sse should work', async () => {
Expand Down Expand Up @@ -640,6 +672,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
],
});
await sseTransport.close();


const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes('mcp middleware start'));
assert.ok(middlewareEndTracelog.includes('mcp middleware end'));
});

it('multiple streamable should work', async () => {
Expand Down Expand Up @@ -785,6 +824,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
const logContent = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/mcp-app/mcp-app-web.log'));

assert.ok(logContent.includes('startNotificationStream finish'));


const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes(' POST /mcp/test/stream] mcp middleware start'));
assert.ok(middlewareEndTracelog.includes(' POST /mcp/test/stream] mcp middleware end'));
});

it('multiple stateless streamable should work', async () => {
Expand Down Expand Up @@ -924,6 +970,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {

await streamableTransport.terminateSession();
await streamableClient.close();


const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');

assert.ok(middlewareStartTracelog.includes(' /mcp/test/stateless/stream] mcp middleware start'));
assert.ok(middlewareEndTracelog.includes(' /mcp/test/stateless/stream] mcp middleware end'));
});
}
});
Loading
Loading