Skip to content

Commit 176bf20

Browse files
committed
feat: add mcp global middleware
1 parent a6d2b0c commit 176bf20

File tree

8 files changed

+144
-7
lines changed

8 files changed

+144
-7
lines changed

plugin/controller/lib/impl/mcp/MCPControllerRegister.ts

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
MCPPromptMeta,
1212
MCPToolMeta,
1313
ControllerType,
14+
EggContext,
1415
} from '@eggjs/tegg';
1516
import { EggPrototype } from '@eggjs/tegg-metadata';
1617
import { EggContainerFactory } from '@eggjs/tegg-runtime';
@@ -21,6 +22,7 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
2122
import { isInitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
2223
import { MCPProtocols } from '@eggjs/mcp-proxy/types';
2324
import awaitEvent from 'await-event';
25+
import compose from 'koa-compose';
2426

2527
import getRawBody from 'raw-body';
2628
import contentType from 'content-type';
@@ -90,6 +92,7 @@ export class MCPControllerRegister implements ControllerRegister {
9092
}
9193
>>();
9294
static hooks: MCPControllerHook[] = [];
95+
globalMiddlewares: compose.ComposedMiddleware<EggContext>;
9396

9497
static create(proto: EggPrototype, controllerMeta: ControllerMetadata, app: Application) {
9598
assert(controllerMeta.type === ControllerType.MCP, 'controller meta type is not MCP');
@@ -160,7 +163,10 @@ export class MCPControllerRegister implements ControllerRegister {
160163
sessionIdGenerator: undefined,
161164
});
162165
self.statelessTransport = transport;
163-
const mw = self.app.middleware.teggCtxLifecycleMiddleware();
166+
let mw = self.app.middleware.teggCtxLifecycleMiddleware();
167+
if (self.globalMiddlewares) {
168+
mw = compose([ mw, self.globalMiddlewares ]);
169+
}
164170
const initHandler = async (ctx: Context) => {
165171
if (MCPControllerRegister.hooks.length > 0) {
166172
for (const hook of MCPControllerRegister.hooks) {
@@ -218,7 +224,10 @@ export class MCPControllerRegister implements ControllerRegister {
218224
mcpStreamServerInit() {
219225
const allRouterFunc = this.router.all;
220226
const self = this;
221-
const mw = self.app.middleware.teggCtxLifecycleMiddleware();
227+
let mw = self.app.middleware.teggCtxLifecycleMiddleware();
228+
if (self.globalMiddlewares) {
229+
mw = compose([ mw, self.globalMiddlewares ]);
230+
}
222231
const initHandler = async (ctx: Context) => {
223232
ctx.respond = false;
224233
if (MCPControllerRegister.hooks.length > 0) {
@@ -377,7 +386,10 @@ export class MCPControllerRegister implements ControllerRegister {
377386

378387
sseCtxStorageRun(ctx: Context, transport: SSEServerTransport) {
379388
const self = this;
380-
const mw = this.app.middleware.teggCtxLifecycleMiddleware();
389+
let mw = this.app.middleware.teggCtxLifecycleMiddleware();
390+
if (self.globalMiddlewares) {
391+
mw = compose([ mw, self.globalMiddlewares ]);
392+
}
381393
const closeFunc = transport.onclose;
382394
transport.onclose = (...args) => {
383395
closeFunc?.(...args);
@@ -436,6 +448,11 @@ export class MCPControllerRegister implements ControllerRegister {
436448
// if (aclMiddleware) {
437449
// methodMiddlewares.push(aclMiddleware);
438450
// }
451+
452+
let mw = self.app.middleware.teggCtxLifecycleMiddleware();
453+
if (self.globalMiddlewares) {
454+
mw = compose([ mw, self.globalMiddlewares ]);
455+
}
439456
const messageHander = async (ctx: Context) => {
440457
const sessionId = ctx.query.sessionId;
441458

@@ -476,13 +493,30 @@ export class MCPControllerRegister implements ControllerRegister {
476493
Reflect.apply(routerFunc, this.router, [
477494
'chairMcpMessage',
478495
self.mcpConfig.sseMessagePath,
479-
...[],
496+
...[ mw ],
480497
messageHander,
481498
]);
482499
}
483500

501+
getGlobalMiddleware() {
502+
const middlewareNames = this.app.config.mcp.middleware || [];
503+
const middlewares: compose.Middleware<EggContext>[] = [];
504+
for (const name of middlewareNames) {
505+
const middlewareFactory = (this.app as unknown as any).middlewares[name];
506+
if (!middlewareFactory) {
507+
throw new TypeError(`Middleware ${name} not found`);
508+
}
509+
const options = (this.app.config as any)[name] || {};
510+
const mw = middlewareFactory(options, this.app);
511+
(mw as any)._name = name;
512+
middlewares.push(mw);
513+
}
514+
this.globalMiddlewares = compose(middlewares);
515+
}
516+
484517
async register() {
485518
if (!this.mcpServerHelper) {
519+
this.getGlobalMiddleware();
486520
this.mcpServerHelper = new MCPServerHelper({
487521
name: this.controllerMeta.name ?? `chair-mcp-${this.app.name}-server`,
488522
version: this.controllerMeta.version ?? '1.0.0',

plugin/controller/lib/impl/mcp/MCPServerHelper.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ export class MCPServerHelper {
5050
) as ReturnType<ReadResourceCallback>;
5151
};
5252
const name = resourceMeta.mcpName ?? resourceMeta.name;
53-
if (resourceMeta.uri || resourceMeta.template) {
54-
this.server.registerResource(name, resourceMeta.uri ?? resourceMeta.template!, resourceMeta.metadata ?? {}, handler);
53+
if (resourceMeta.uri) {
54+
this.server.registerResource(name, resourceMeta.uri, resourceMeta.metadata ?? {}, handler);
55+
} else if (resourceMeta.template) {
56+
this.server.registerResource(name, resourceMeta.template, resourceMeta.metadata ?? {}, handler);
5557
} else {
5658
throw new Error(`MCPResource ${name} must have uri or template`);
5759
}

plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import {
1313
Inject,
1414
ContextProto,
1515
ToolArgsSchema,
16+
Extra,
17+
ToolExtra,
1618
} from '@eggjs/tegg';
1719
import z from 'zod';
1820

@@ -90,6 +92,18 @@ export class McpController {
9092
};
9193
}
9294

95+
@MCPTool()
96+
async traceTest(@Extra() extra: ToolExtra): Promise<MCPToolResponse> {
97+
return {
98+
content: [
99+
{
100+
type: 'text',
101+
text: `hello ${extra.requestInfo?.headers.trace}`,
102+
},
103+
],
104+
};
105+
}
106+
93107

94108
@MCPResource({
95109
template: [
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
'use strict';
2+
3+
module.exports = () => {
4+
return async function tracelog(ctx, next) {
5+
ctx.req.headers.trace = 'middleware';
6+
await next();
7+
};
8+
};

plugin/controller/test/fixtures/apps/mcp-app/config/config.default.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ module.exports = function() {
1414
sessionIdGenerator: ctx => {
1515
return ctx.request.headers['custom-session-id'] || randomUUID();
1616
},
17+
middleware: [
18+
'tracelog',
19+
],
1720
},
1821
};
1922
return config;

plugin/controller/test/mcp/mcp.test.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
133133
description: undefined,
134134
name: 'echoUser',
135135
},
136+
{
137+
description: undefined,
138+
name: 'traceTest',
139+
},
136140
]);
137141

138142
const toolRes = await sseClient.callTool({
@@ -152,6 +156,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
152156
assert.deepEqual(userRes, {
153157
content: [{ type: 'text', text: 'hello akita' }],
154158
});
159+
160+
const traceRes = await sseClient.callTool({
161+
name: 'traceTest',
162+
arguments: {},
163+
});
164+
assert.deepEqual(traceRes, {
165+
content: [{ type: 'text', text: 'hello middleware' }],
166+
});
155167
// notification
156168
const notificationResp = await startNotificationTool(sseClient);
157169
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -260,6 +272,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
260272
description: undefined,
261273
name: 'echoUser',
262274
},
275+
{
276+
description: undefined,
277+
name: 'traceTest',
278+
},
263279
]);
264280

265281
const toolRes = await streamableClient.callTool({
@@ -279,6 +295,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
279295
assert.deepEqual(userRes, {
280296
content: [{ type: 'text', text: 'hello akita' }],
281297
});
298+
299+
const traceRes = await streamableClient.callTool({
300+
name: 'traceTest',
301+
arguments: {},
302+
});
303+
assert.deepEqual(traceRes, {
304+
content: [{ type: 'text', text: 'hello middleware' }],
305+
});
282306
// notification
283307
const notificationResp = await startNotificationTool(streamableClient);
284308
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -391,6 +415,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
391415
description: undefined,
392416
name: 'echoUser',
393417
},
418+
{
419+
description: undefined,
420+
name: 'traceTest',
421+
},
394422
]);
395423

396424
const toolRes = await streamableClient.callTool({
@@ -410,6 +438,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
410438
assert.deepEqual(userRes, {
411439
content: [{ type: 'text', text: 'hello akita' }],
412440
});
441+
442+
const traceRes = await streamableClient.callTool({
443+
name: 'traceTest',
444+
arguments: {},
445+
});
446+
assert.deepEqual(traceRes, {
447+
content: [{ type: 'text', text: 'hello middleware' }],
448+
});
413449
// notification
414450
const notificationResp = await startNotificationTool(streamableClient);
415451
await new Promise(resolve => setTimeout(resolve, 5000));

plugin/controller/test/mcp/mcpCluster.test.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
143143
description: undefined,
144144
name: 'echoUser',
145145
},
146+
{
147+
description: undefined,
148+
name: 'traceTest',
149+
},
146150
]);
147151

148152
const toolRes = await sseClient.callTool({
@@ -161,6 +165,14 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
161165
assert.deepEqual(userRes, {
162166
content: [{ type: 'text', text: 'hello akita' }],
163167
});
168+
169+
const traceRes = await sseClient.callTool({
170+
name: 'traceTest',
171+
arguments: {},
172+
});
173+
assert.deepEqual(traceRes, {
174+
content: [{ type: 'text', text: 'hello middleware' }],
175+
});
164176
// notification
165177
const notificationResp = await startNotificationTool(sseClient);
166178
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -269,6 +281,10 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
269281
description: undefined,
270282
name: 'echoUser',
271283
},
284+
{
285+
description: undefined,
286+
name: 'traceTest',
287+
},
272288
]);
273289

274290
const toolRes = await streamableClient.callTool({
@@ -287,6 +303,14 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
287303
assert.deepEqual(userRes, {
288304
content: [{ type: 'text', text: 'hello akita' }],
289305
});
306+
307+
const traceRes = await streamableClient.callTool({
308+
name: 'traceTest',
309+
arguments: {},
310+
});
311+
assert.deepEqual(traceRes, {
312+
content: [{ type: 'text', text: 'hello middleware' }],
313+
});
290314
// notification
291315
const notificationResp = await startNotificationTool(streamableClient);
292316
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -399,6 +423,10 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
399423
description: undefined,
400424
name: 'echoUser',
401425
},
426+
{
427+
description: undefined,
428+
name: 'traceTest',
429+
},
402430
]);
403431

404432
const toolRes = await streamableClient.callTool({
@@ -417,6 +445,14 @@ describe('plugin/controller/test/mcp/mcpCluster.test.ts', () => {
417445
assert.deepEqual(userRes, {
418446
content: [{ type: 'text', text: 'hello akita' }],
419447
});
448+
449+
const traceRes = await streamableClient.callTool({
450+
name: 'traceTest',
451+
arguments: {},
452+
});
453+
assert.deepEqual(traceRes, {
454+
content: [{ type: 'text', text: 'hello middleware' }],
455+
});
420456
// notification
421457
const notificationResp = await startNotificationTool(streamableClient);
422458
await new Promise(resolve => setTimeout(resolve, 5000));

plugin/mcp-proxy/index.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import cluster from 'node:cluster';
1414
import { MCPControllerRegister, MCPControllerHook } from '@eggjs/tegg-controller-plugin/lib/impl/mcp/MCPControllerRegister';
1515
import querystring from 'node:querystring';
1616
import url from 'node:url';
17+
import compose from 'koa-compose';
1718
import { MCPProtocols } from './types';
1819

1920
const MAXIMUM_MESSAGE_SIZE = '4mb';
@@ -118,7 +119,10 @@ export const MCPProxyHook: MCPControllerHook = {
118119
const sessionId = transport.sessionId!;
119120
self.streamTransports[sessionId] = transport;
120121
self.app.mcpProxy.setProxyHandler(MCPProtocols.STREAM, async (req: http.IncomingMessage, res: http.ServerResponse) => {
121-
const mw = self.app.middleware.teggCtxLifecycleMiddleware();
122+
let mw = self.app.middleware.teggCtxLifecycleMiddleware();
123+
if (self.globalMiddlewares) {
124+
mw = compose([ mw, self.globalMiddlewares ]);
125+
}
122126
const ctx = self.app.createContext(req, res) as unknown as Context;
123127
if (MCPControllerRegister.hooks.length > 0) {
124128
for (const hook of MCPControllerRegister.hooks) {

0 commit comments

Comments
 (0)