Skip to content

Commit 7215645

Browse files
authored
feat: add mcp middleware hook (#344)
<!-- Thank you for your pull request. Please review below requirements. Bug fixes and new features should include tests and possibly benchmarks. Contributors guide: https://github.com/eggjs/egg/blob/master/CONTRIBUTING.md 感谢您贡献代码。请确认下列 checklist 的完成情况。 Bug 修复和新功能必须包含测试,必要时请附上性能测试。 Contributors guide: https://github.com/eggjs/egg/blob/master/CONTRIBUTING.md --> ##### Checklist <!-- Remove items that do not apply. For completed items, change [ ] to [x]. --> - [ ] `npm test` passes - [ ] tests and/or benchmarks are included - [ ] documentation is changed or added - [ ] commit message follows commit guidelines ##### Affected core subsystem(s) <!-- Provide affected core subsystem(s). --> ##### Description of change <!-- Provide a description of the change below this comment. --> <!-- - any feature? - close https://github.com/eggjs/egg/ISSUE_URL -->
1 parent aa5daf7 commit 7215645

File tree

9 files changed

+183
-12
lines changed

9 files changed

+183
-12
lines changed

plugin/controller/app.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ export default class ControllerAppBootHook {
132132
// The HTTPControllerRegister will collect all the methods
133133
// and register methods after collect is done.
134134
HTTPControllerRegister.instance?.doRegister(this.app.rootProtoManager);
135+
136+
this.app.config.mcp.hooks = MCPControllerRegister.hooks;
135137
}
136138

137139
async willReady() {

plugin/controller/app/middleware/mcp_body_middleware.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,24 @@ export default () => {
1818
});
1919
if (res) {
2020
ctx.disableBodyParser = true;
21+
try {
22+
for (const hook of ctx.app.config.mcp.hooks) {
23+
await hook.middlewareStart?.(ctx);
24+
}
25+
await next();
26+
if (!ctx.mcpArg) {
27+
ctx.mcpArg = JSON.parse(ctx.response.header['mcp-proxy-arg'] as string ?? '{}');
28+
}
29+
for (const hook of ctx.app.config.mcp.hooks) {
30+
await hook.middlewareEnd?.(ctx);
31+
}
32+
} catch (e) {
33+
for (const hook of ctx.app.config.mcp.hooks) {
34+
await hook.middlewareError?.(ctx, e);
35+
}
36+
}
37+
} else {
38+
await next();
2139
}
22-
await next();
2340
};
2441
};

plugin/controller/lib/impl/mcp/MCPControllerRegister.ts

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { ControllerRegister } from '../../ControllerRegister';
1919
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
2020
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
2121
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
22-
import { isInitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
22+
import { isInitializeRequest, isJSONRPCRequest, JSONRPCMessage, MessageExtraInfo } from '@modelcontextprotocol/sdk/types.js';
2323
import { MCPProtocols } from '@eggjs/mcp-proxy/types';
2424
import awaitEvent from 'await-event';
2525
import compose from 'koa-compose';
@@ -43,6 +43,11 @@ export interface MCPControllerHook {
4343
preProxy?: (ctx: Context, proxyReq: http.IncomingMessage, proxyResp: http.ServerResponse) => Promise<void>
4444
schemaLoader?: (controllerMeta: MCPControllerMeta, meta: MCPPromptMeta | MCPToolMeta) => Promise<Parameters<McpServer['tool']>['2'] | Parameters<McpServer['prompt']>['2']>
4545
checkAndRunProxy?: (ctx: Context, type: MCPProtocols, sessionId: string) => Promise<boolean>;
46+
47+
// middleware
48+
middlewareStart?: (ctx: Context) => Promise<void>
49+
middlewareEnd?: (ctx: Context) => Promise<void>
50+
middlewareError?: (ctx: Context, e: Error) => Promise<void>
4651
}
4752

4853
class InnerSSEServerTransport extends SSEServerTransport {
@@ -183,6 +188,14 @@ export class MCPControllerRegister implements ControllerRegister {
183188
if (self.globalMiddlewares) {
184189
mw = compose([ mw, self.globalMiddlewares ]);
185190
}
191+
const onmessage = transport.onmessage;
192+
193+
transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
194+
if (self.app.currentContext) {
195+
self.app.currentContext.mcpArg = message;
196+
}
197+
onmessage && await onmessage(message, extra);
198+
};
186199
const initHandler = async (ctx: Context) => {
187200
if (MCPControllerRegister.hooks.length > 0) {
188201
for (const hook of MCPControllerRegister.hooks) {
@@ -194,7 +207,6 @@ export class MCPControllerRegister implements ControllerRegister {
194207
'content-type': 'text/event-stream',
195208
'transfer-encoding': 'chunked',
196209
});
197-
198210
await ctx.app.ctxStorage.run(ctx, async () => {
199211
await mw(ctx, async () => {
200212
await transport.handleRequest(ctx.req, ctx.res);
@@ -253,7 +265,7 @@ export class MCPControllerRegister implements ControllerRegister {
253265
}
254266
const sessionId = ctx.req.headers['mcp-session-id'] as string | undefined;
255267
if (!sessionId) {
256-
const ct = contentType.parse(ctx.req.headers['content-type'] ?? '');
268+
const ct = contentType.parse(ctx.req.headers['content-type'] || 'application/json');
257269

258270
let body;
259271

@@ -307,6 +319,15 @@ export class MCPControllerRegister implements ControllerRegister {
307319
transport,
308320
);
309321

322+
const onmessage = transport.onmessage;
323+
324+
transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
325+
if (self.app.currentContext) {
326+
self.app.currentContext.mcpArg = message;
327+
}
328+
onmessage && await onmessage(message, extra);
329+
};
330+
310331
await ctx.app.ctxStorage.run(ctx, async () => {
311332
await mw(ctx, async () => {
312333
await transport.handleRequest(ctx.req, ctx.res, body);
@@ -338,6 +359,7 @@ export class MCPControllerRegister implements ControllerRegister {
338359
'content-type': 'text/event-stream',
339360
'transfer-encoding': 'chunked',
340361
});
362+
341363
await ctx.app.ctxStorage.run(ctx, async () => {
342364
await mw(ctx, async () => {
343365
await transport.handleRequest(ctx.req, ctx.res);
@@ -442,7 +464,8 @@ export class MCPControllerRegister implements ControllerRegister {
442464
};
443465
const messageFunc = transport.onmessage;
444466
self.sseTransportsRequestMap.set(transport, {});
445-
transport.onmessage = async (...args: [JSONRPCMessage]) => {
467+
transport.onmessage = async (message: JSONRPCMessage, extra?: MessageExtraInfo) => {
468+
const args = [ message, extra ];
446469
// 这里需要 new 一个新的 ctx,否则 ContextProto 会未被初始化
447470
const socket = new Socket();
448471
const req = new IncomingMessage(socket);
@@ -451,6 +474,7 @@ export class MCPControllerRegister implements ControllerRegister {
451474
req.url = self.mcpConfig.getSseInitPath(name);
452475
req.headers = {
453476
...ctx.req.headers,
477+
...extra?.requestInfo?.headers,
454478
accept: 'application/json, text/event-stream',
455479
'content-type': 'application/json',
456480
};
@@ -462,12 +486,12 @@ export class MCPControllerRegister implements ControllerRegister {
462486
}
463487
}
464488
await mw(newCtx, async () => {
465-
messageFunc!(...args);
489+
messageFunc!(message, extra);
466490
if (isJSONRPCRequest(args[0])) {
467491
const map = self.sseTransportsRequestMap.get(transport)!;
468492
const wait = new Promise<null>((resolve, reject) => {
469-
if ('id' in args[0]) {
470-
map[args[0].id] = { resolve, reject };
493+
if (extra && 'id' in extra) {
494+
map[extra.id as string] = { resolve, reject };
471495
}
472496
});
473497
await wait;
@@ -500,7 +524,16 @@ export class MCPControllerRegister implements ControllerRegister {
500524
}
501525
self.app.logger.info('message coming', sessionId);
502526
try {
503-
await self.transports[sessionId].handlePostMessage(ctx.req, ctx.res);
527+
const ct = contentType.parse(ctx.req.headers['content-type'] ?? '');
528+
529+
const rawBody = await getRawBody(ctx.req, {
530+
limit: '4mb',
531+
encoding: ct.parameters.charset ?? 'utf-8',
532+
});
533+
534+
const body = JSON.parse(rawBody);
535+
ctx.mcpArg = body;
536+
await self.transports[sessionId].handlePostMessage(ctx.req, ctx.res, body);
504537
} catch (error) {
505538
self.app.logger.error('Error handling MCP message', error);
506539
if (!ctx.res.headersSent) {

plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ export class McpController {
8080
};
8181
}
8282

83+
@MCPTool()
84+
async mockError(): Promise<MCPToolResponse> {
85+
throw new Error('mock error');
86+
}
87+
88+
8389
@MCPTool()
8490
async echoUser(): Promise<MCPToolResponse> {
8591
return {

plugin/controller/test/fixtures/apps/mcp-app/config/config.default.js

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
'use strict';
22
// eslint-disable-next-line @typescript-eslint/no-var-requires
33
const { randomUUID } = require('node:crypto');
4+
// eslint-disable-next-line @typescript-eslint/no-var-requires
5+
const path = require('node:path');
46

5-
module.exports = function() {
7+
module.exports = function(appInfo) {
68
const config = {
79
keys: 'test key',
810
mcp: {
@@ -20,6 +22,17 @@ module.exports = function() {
2022
},
2123
},
2224
},
25+
customLogger: {
26+
mcpMiddewareStartLogger: {
27+
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareStart.log'),
28+
},
29+
mcpMiddewareEndLogger: {
30+
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareEnd.log'),
31+
},
32+
mcpMiddewareErrorLogger: {
33+
file: path.join(appInfo.root, 'logs', 'tracelog', 'mcpMiddlewareError.log'),
34+
},
35+
},
2336
};
2437
return config;
2538
};

plugin/controller/test/fixtures/apps/mcp-app/hook-plugin/lib/MCPControllerHook.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ export const GetAlipayTeggHook = (app: Application) => {
3232
async preProxy(ctx) {
3333
setUser(ctx);
3434
},
35+
async middlewareStart(ctx) {
36+
ctx.mcpStartTime = Date.now();
37+
ctx.getLogger('mcpMiddewareStartLogger').info('mcp middleware start');
38+
},
39+
async middlewareEnd(ctx) {
40+
ctx.getLogger('mcpMiddewareEndLogger').info('mcp middleware end, arg: ', JSON.stringify(ctx.mcpArg), `, time: ${Date.now() - ctx.mcpStartTime}`);
41+
},
42+
async middlewareError(ctx, e) {
43+
ctx.getLogger('mcpMiddewareErrorLogger').info('mcp middleware error: ', e);
44+
},
3545
};
3646

3747
return AlipayTeggControllerHook;

plugin/controller/test/mcp/mcp.test.ts

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
130130
description: undefined,
131131
name: 'bar',
132132
},
133+
{
134+
description: undefined,
135+
name: 'mockError',
136+
},
133137
{
134138
description: undefined,
135139
name: 'echoUser',
@@ -219,6 +223,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
219223
],
220224
});
221225
await sseTransport.close();
226+
227+
228+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
229+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
230+
231+
assert.ok(middlewareStartTracelog.includes('mcp middleware start'));
232+
assert.ok(middlewareEndTracelog.includes('mcp middleware end, arg: {'));
222233
});
223234

224235
it('streamable should work', async () => {
@@ -269,6 +280,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
269280
description: undefined,
270281
name: 'bar',
271282
},
283+
{
284+
description: undefined,
285+
name: 'mockError',
286+
},
272287
{
273288
description: undefined,
274289
name: 'echoUser',
@@ -364,6 +379,12 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
364379
const logContent = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/mcp-app/mcp-app-web.log'));
365380

366381
assert.ok(logContent.includes('startNotificationStream finish'));
382+
383+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
384+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
385+
386+
assert.ok(middlewareStartTracelog.includes(' POST /mcp/stream] mcp middleware start'));
387+
assert.ok(middlewareEndTracelog.includes(' POST /mcp/stream] mcp middleware end, arg: '));
367388
});
368389

369390
it('stateless streamable should work', async () => {
@@ -412,6 +433,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
412433
description: undefined,
413434
name: 'bar',
414435
},
436+
{
437+
description: undefined,
438+
name: 'mockError',
439+
},
415440
{
416441
description: undefined,
417442
name: 'echoUser',
@@ -503,6 +528,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
503528

504529
await streamableTransport.terminateSession();
505530
await streamableClient.close();
531+
532+
533+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
534+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
535+
536+
assert.ok(middlewareStartTracelog.includes(' POST /mcp/stateless/stream] mcp middleware start'));
537+
assert.ok(middlewareEndTracelog.includes(' POST /mcp/stateless/stream] mcp middleware end, arg: {'));
506538
});
507539

508540
it('multiple sse should work', async () => {
@@ -640,6 +672,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
640672
],
641673
});
642674
await sseTransport.close();
675+
676+
677+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
678+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
679+
680+
assert.ok(middlewareStartTracelog.includes('mcp middleware start'));
681+
assert.ok(middlewareEndTracelog.includes('mcp middleware end'));
643682
});
644683

645684
it('multiple streamable should work', async () => {
@@ -785,6 +824,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
785824
const logContent = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/mcp-app/mcp-app-web.log'));
786825

787826
assert.ok(logContent.includes('startNotificationStream finish'));
827+
828+
829+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
830+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
831+
832+
assert.ok(middlewareStartTracelog.includes(' POST /mcp/test/stream] mcp middleware start'));
833+
assert.ok(middlewareEndTracelog.includes(' POST /mcp/test/stream] mcp middleware end'));
788834
});
789835

790836
it('multiple stateless streamable should work', async () => {
@@ -924,6 +970,13 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
924970

925971
await streamableTransport.terminateSession();
926972
await streamableClient.close();
973+
974+
975+
const middlewareStartTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareStart.log'), 'utf-8');
976+
const middlewareEndTracelog = await fs.readFile(path.join(__dirname, '../fixtures/apps/mcp-app/logs/tracelog/mcpMiddlewareEnd.log'), 'utf-8');
977+
978+
assert.ok(middlewareStartTracelog.includes(' /mcp/test/stateless/stream] mcp middleware start'));
979+
assert.ok(middlewareEndTracelog.includes(' /mcp/test/stateless/stream] mcp middleware end'));
927980
});
928981
}
929982
});

0 commit comments

Comments
 (0)