Skip to content

Commit 5716e0a

Browse files
authored
fix: sse new ctx (#323)
<!-- Thank you for your pull request. Please review below requirements. Bug fixes and new features should include tests and possibly benchmarks. Contributors guide: https://github.com/eggjs/egg/blob/master/CONTRIBUTING.md 感谢您贡献代码。请确认下列 checklist 的完成情况。 Bug 修复和新功能必须包含测试,必要时请附上性能测试。 Contributors guide: https://github.com/eggjs/egg/blob/master/CONTRIBUTING.md --> ##### Checklist <!-- Remove items that do not apply. For completed items, change [ ] to [x]. --> - [ ] `npm test` passes - [ ] tests and/or benchmarks are included - [ ] documentation is changed or added - [ ] commit message follows commit guidelines ##### Affected core subsystem(s) <!-- Provide affected core subsystem(s). --> ##### Description of change <!-- Provide a description of the change below this comment. --> <!-- - any feature? - close https://github.com/eggjs/egg/ISSUE_URL --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a plugin system for controller hooks, allowing user information to be injected into request contexts based on authorization headers. - Added a new tool, "echoUser", which greets the authenticated user. - **Bug Fixes** - Improved request header handling for Server-Sent Events to ensure all original headers are included. - **Tests** - Enhanced tests to cover authentication scenarios and verify the new "echoUser" tool functionality. - **Chores** - Added configuration and setup for the new hook plugin. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent ab43456 commit 5716e0a

File tree

9 files changed

+302
-11
lines changed

9 files changed

+302
-11
lines changed

plugin/controller/lib/impl/mcp/MCPControllerRegister.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,17 @@ export class MCPControllerRegister implements ControllerRegister {
397397
req.method = 'POST';
398398
req.url = self.mcpConfig.sseInitPath;
399399
req.headers = {
400+
...ctx.req.headers,
400401
accept: 'application/json, text/event-stream',
401402
'content-type': 'application/json',
402403
};
403404
const newCtx = self.app.createContext(req, res) as unknown as Context;
404405
await ctx.app.ctxStorage.run(newCtx, async () => {
406+
if (MCPControllerRegister.hooks.length > 0) {
407+
for (const hook of MCPControllerRegister.hooks) {
408+
await hook.preHandle?.(newCtx);
409+
}
410+
}
405411
await mw(newCtx, async () => {
406412
messageFunc!(...args);
407413
if (isJSONRPCRequest(args[0])) {

plugin/controller/test/fixtures/apps/mcp-app/app/controller/McpController.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ export class McpController {
4646
@Inject()
4747
commonService: CommonService;
4848

49+
@Inject()
50+
user: any;
51+
4952
@MCPPrompt()
5053
async foo(@PromptArgsSchema(PromptType) args: PromptArgs<typeof PromptType>): Promise<MCPPromptResponse> {
5154
this.logger.info('hello world');
@@ -75,6 +78,18 @@ export class McpController {
7578
};
7679
}
7780

81+
@MCPTool()
82+
async echoUser(): Promise<MCPToolResponse> {
83+
return {
84+
content: [
85+
{
86+
type: 'text',
87+
text: `hello ${this.user}`,
88+
},
89+
],
90+
};
91+
}
92+
7893

7994
@MCPResource({
8095
template: [

plugin/controller/test/fixtures/apps/mcp-app/config/plugin.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
'use strict';
2+
// eslint-disable-next-line @typescript-eslint/no-var-requires
3+
const path = require('path');
24

35
exports.tracer = {
46
package: 'egg-tracer',
@@ -24,3 +26,9 @@ exports.mcpProxy = {
2426
package: '@eggjs/mcp-proxy',
2527
enable: true,
2628
};
29+
30+
31+
exports.hookPlugin = {
32+
path: path.join(__dirname, '../hook-plugin'),
33+
enable: true,
34+
};
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import type { Application } from 'egg';
2+
import { MCPControllerRegister } from '@eggjs/tegg-controller-plugin/lib/impl/mcp/MCPControllerRegister';
3+
import { GetAlipayTeggHook } from './lib/MCPControllerHook';
4+
5+
export default class ControllerAppBootHook {
6+
#app: Application;
7+
8+
constructor(app: Application) {
9+
this.#app = app;
10+
}
11+
12+
configWillLoad() {
13+
MCPControllerRegister.addHook(GetAlipayTeggHook(this.#app));
14+
}
15+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import type { Application, Context } from 'egg';
2+
import { MCPControllerHook } from '@eggjs/tegg-controller-plugin/lib/impl/mcp/MCPControllerRegister';
3+
4+
export const GetAlipayTeggHook = (app: Application) => {
5+
const setUser = (ctx: Context) => {
6+
ctx.set({
7+
'content-type': 'text/event-stream',
8+
'cache-control': 'no-cache',
9+
'transfer-encoding': 'chunked',
10+
});
11+
try {
12+
const auth = ctx.get('authorization');
13+
const atitString = Buffer.from(
14+
auth.substring('Bearer '.length),
15+
'base64',
16+
).toString('utf8');
17+
ctx.user = atitString;
18+
} catch (e) {
19+
app.logger.warn('get user failed: ', e);
20+
}
21+
};
22+
const AlipayTeggControllerHook: MCPControllerHook = {
23+
async preHandle(ctx) {
24+
setUser(ctx);
25+
},
26+
async preHandleInitHandle(ctx) {
27+
setUser(ctx);
28+
},
29+
async preSSEInitHandle(ctx) {
30+
setUser(ctx);
31+
},
32+
async preProxy(ctx) {
33+
setUser(ctx);
34+
},
35+
};
36+
37+
return AlipayTeggControllerHook;
38+
};
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"name": "@eggjs/hook-plugin",
3+
"eggPlugin": {
4+
"name": "hookPlugin"
5+
}
6+
}

plugin/controller/test/mcp/mcp.test.ts

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,29 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
9090
});
9191
const baseUrl = await app.httpRequest()
9292
.get('/mcp/init').url;
93-
const sseTransport = new SSEClientTransport(new URL(baseUrl));
93+
const sseTransport = new SSEClientTransport(
94+
new URL(baseUrl),
95+
{
96+
authProvider: {
97+
get redirectUrl() { return 'http://localhost/callback'; },
98+
get clientMetadata() { return { redirect_uris: [ 'http://localhost/callback' ] }; },
99+
clientInformation: () => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' }),
100+
tokens: () => {
101+
return {
102+
access_token: Buffer.from('akita').toString('base64'),
103+
token_type: 'Bearer',
104+
};
105+
},
106+
// eslint-disable-next-line @typescript-eslint/no-empty-function
107+
saveTokens: () => {},
108+
// eslint-disable-next-line @typescript-eslint/no-empty-function
109+
redirectToAuthorization: () => {},
110+
// eslint-disable-next-line @typescript-eslint/no-empty-function
111+
saveCodeVerifier: () => {},
112+
codeVerifier: () => '',
113+
},
114+
},
115+
);
94116
const sseNotifications: { level: string, data: string }[] = [];
95117
sseClient.setNotificationHandler(LoggingMessageNotificationSchema, notification => {
96118
sseNotifications.push({ level: notification.params.level, data: notification.params.data as string });
@@ -107,6 +129,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
107129
description: undefined,
108130
name: 'bar',
109131
},
132+
{
133+
description: undefined,
134+
name: 'echoUser',
135+
},
110136
]);
111137

112138
const toolRes = await sseClient.callTool({
@@ -118,6 +144,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
118144
assert.deepEqual(toolRes, {
119145
content: [{ type: 'text', text: 'npm package: aaa not found' }],
120146
});
147+
148+
const userRes = await sseClient.callTool({
149+
name: 'echoUser',
150+
arguments: {},
151+
});
152+
assert.deepEqual(userRes, {
153+
content: [{ type: 'text', text: 'hello akita' }],
154+
});
121155
// notification
122156
const notificationResp = await startNotificationTool(sseClient);
123157
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -181,7 +215,30 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
181215
});
182216
const baseUrl = await app.httpRequest()
183217
.post('/mcp/stream').url;
184-
const streamableTransport = new StreamableHTTPClientTransport(new URL(baseUrl), { requestInit: { headers: { 'custom-session-id': 'custom-session-id' } } });
218+
const streamableTransport = new StreamableHTTPClientTransport(
219+
new URL(baseUrl),
220+
{
221+
authProvider: {
222+
get redirectUrl() { return 'http://localhost/callback'; },
223+
get clientMetadata() { return { redirect_uris: [ 'http://localhost/callback' ] }; },
224+
clientInformation: () => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' }),
225+
tokens: () => {
226+
return {
227+
access_token: Buffer.from('akita').toString('base64'),
228+
token_type: 'Bearer',
229+
};
230+
},
231+
// eslint-disable-next-line @typescript-eslint/no-empty-function
232+
saveTokens: () => {},
233+
// eslint-disable-next-line @typescript-eslint/no-empty-function
234+
redirectToAuthorization: () => {},
235+
// eslint-disable-next-line @typescript-eslint/no-empty-function
236+
saveCodeVerifier: () => {},
237+
codeVerifier: () => '',
238+
},
239+
requestInit: { headers: { 'custom-session-id': 'custom-session-id' } },
240+
},
241+
);
185242
const streamableNotifications: { level: string, data: string }[] = [];
186243
streamableClient.setNotificationHandler(LoggingMessageNotificationSchema, notification => {
187244
streamableNotifications.push({ level: notification.params.level, data: notification.params.data as string });
@@ -199,6 +256,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
199256
description: undefined,
200257
name: 'bar',
201258
},
259+
{
260+
description: undefined,
261+
name: 'echoUser',
262+
},
202263
]);
203264

204265
const toolRes = await streamableClient.callTool({
@@ -210,6 +271,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
210271
assert.deepEqual(toolRes, {
211272
content: [{ type: 'text', text: 'npm package: aaa not found' }],
212273
});
274+
275+
const userRes = await streamableClient.callTool({
276+
name: 'echoUser',
277+
arguments: {},
278+
});
279+
assert.deepEqual(userRes, {
280+
content: [{ type: 'text', text: 'hello akita' }],
281+
});
213282
// notification
214283
const notificationResp = await startNotificationTool(streamableClient);
215284
await new Promise(resolve => setTimeout(resolve, 5000));
@@ -279,7 +348,29 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
279348
});
280349
const baseUrl = await app.httpRequest()
281350
.post('/mcp/stateless/stream').url;
282-
const streamableTransport = new StreamableHTTPClientTransport(new URL(baseUrl));
351+
const streamableTransport = new StreamableHTTPClientTransport(
352+
new URL(baseUrl),
353+
{
354+
authProvider: {
355+
get redirectUrl() { return 'http://localhost/callback'; },
356+
get clientMetadata() { return { redirect_uris: [ 'http://localhost/callback' ] }; },
357+
clientInformation: () => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' }),
358+
tokens: () => {
359+
return {
360+
access_token: Buffer.from('akita').toString('base64'),
361+
token_type: 'Bearer',
362+
};
363+
},
364+
// eslint-disable-next-line @typescript-eslint/no-empty-function
365+
saveTokens: () => {},
366+
// eslint-disable-next-line @typescript-eslint/no-empty-function
367+
redirectToAuthorization: () => {},
368+
// eslint-disable-next-line @typescript-eslint/no-empty-function
369+
saveCodeVerifier: () => {},
370+
codeVerifier: () => '',
371+
},
372+
},
373+
);
283374
const streamableNotifications: { level: string, data: string }[] = [];
284375
streamableClient.setNotificationHandler(LoggingMessageNotificationSchema, notification => {
285376
streamableNotifications.push({ level: notification.params.level, data: notification.params.data as string });
@@ -296,6 +387,10 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
296387
description: undefined,
297388
name: 'bar',
298389
},
390+
{
391+
description: undefined,
392+
name: 'echoUser',
393+
},
299394
]);
300395

301396
const toolRes = await streamableClient.callTool({
@@ -307,6 +402,14 @@ describe('plugin/controller/test/mcp/mcp.test.ts', () => {
307402
assert.deepEqual(toolRes, {
308403
content: [{ type: 'text', text: 'npm package: aaa not found' }],
309404
});
405+
406+
const userRes = await streamableClient.callTool({
407+
name: 'echoUser',
408+
arguments: {},
409+
});
410+
assert.deepEqual(userRes, {
411+
content: [{ type: 'text', text: 'hello akita' }],
412+
});
310413
// notification
311414
const notificationResp = await startNotificationTool(streamableClient);
312415
await new Promise(resolve => setTimeout(resolve, 5000));

0 commit comments

Comments
 (0)