|
1 | 1 | import { createConnection } from "net" |
| 2 | +import { createServer } from "http" |
2 | 3 | import { Log } from "../util/log" |
3 | 4 | import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider" |
4 | 5 |
|
@@ -52,89 +53,91 @@ interface PendingAuth { |
52 | 53 | } |
53 | 54 |
|
54 | 55 | export namespace McpOAuthCallback { |
55 | | - let server: ReturnType<typeof Bun.serve> | undefined |
| 56 | + let server: ReturnType<typeof createServer> | undefined |
56 | 57 | const pendingAuths = new Map<string, PendingAuth>() |
57 | 58 |
|
58 | 59 | const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes |
59 | 60 |
|
60 | | - export async function ensureRunning(): Promise<void> { |
61 | | - if (server) return |
| 61 | + function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) { |
| 62 | + const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`) |
62 | 63 |
|
63 | | - const running = await isPortInUse() |
64 | | - if (running) { |
65 | | - log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) |
| 64 | + if (url.pathname !== OAUTH_CALLBACK_PATH) { |
| 65 | + res.writeHead(404) |
| 66 | + res.end("Not found") |
66 | 67 | return |
67 | 68 | } |
68 | 69 |
|
69 | | - server = Bun.serve({ |
70 | | - port: OAUTH_CALLBACK_PORT, |
71 | | - fetch(req) { |
72 | | - const url = new URL(req.url) |
| 70 | + const code = url.searchParams.get("code") |
| 71 | + const state = url.searchParams.get("state") |
| 72 | + const error = url.searchParams.get("error") |
| 73 | + const errorDescription = url.searchParams.get("error_description") |
73 | 74 |
|
74 | | - if (url.pathname !== OAUTH_CALLBACK_PATH) { |
75 | | - return new Response("Not found", { status: 404 }) |
76 | | - } |
| 75 | + log.info("received oauth callback", { hasCode: !!code, state, error }) |
77 | 76 |
|
78 | | - const code = url.searchParams.get("code") |
79 | | - const state = url.searchParams.get("state") |
80 | | - const error = url.searchParams.get("error") |
81 | | - const errorDescription = url.searchParams.get("error_description") |
82 | | - |
83 | | - log.info("received oauth callback", { hasCode: !!code, state, error }) |
84 | | - |
85 | | - // Enforce state parameter presence |
86 | | - if (!state) { |
87 | | - const errorMsg = "Missing required state parameter - potential CSRF attack" |
88 | | - log.error("oauth callback missing state parameter", { url: url.toString() }) |
89 | | - return new Response(HTML_ERROR(errorMsg), { |
90 | | - status: 400, |
91 | | - headers: { "Content-Type": "text/html" }, |
92 | | - }) |
93 | | - } |
| 77 | + // Enforce state parameter presence |
| 78 | + if (!state) { |
| 79 | + const errorMsg = "Missing required state parameter - potential CSRF attack" |
| 80 | + log.error("oauth callback missing state parameter", { url: url.toString() }) |
| 81 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 82 | + res.end(HTML_ERROR(errorMsg)) |
| 83 | + return |
| 84 | + } |
94 | 85 |
|
95 | | - if (error) { |
96 | | - const errorMsg = errorDescription || error |
97 | | - if (pendingAuths.has(state)) { |
98 | | - const pending = pendingAuths.get(state)! |
99 | | - clearTimeout(pending.timeout) |
100 | | - pendingAuths.delete(state) |
101 | | - pending.reject(new Error(errorMsg)) |
102 | | - } |
103 | | - return new Response(HTML_ERROR(errorMsg), { |
104 | | - headers: { "Content-Type": "text/html" }, |
105 | | - }) |
106 | | - } |
| 86 | + if (error) { |
| 87 | + const errorMsg = errorDescription || error |
| 88 | + if (pendingAuths.has(state)) { |
| 89 | + const pending = pendingAuths.get(state)! |
| 90 | + clearTimeout(pending.timeout) |
| 91 | + pendingAuths.delete(state) |
| 92 | + pending.reject(new Error(errorMsg)) |
| 93 | + } |
| 94 | + res.writeHead(200, { "Content-Type": "text/html" }) |
| 95 | + res.end(HTML_ERROR(errorMsg)) |
| 96 | + return |
| 97 | + } |
107 | 98 |
|
108 | | - if (!code) { |
109 | | - return new Response(HTML_ERROR("No authorization code provided"), { |
110 | | - status: 400, |
111 | | - headers: { "Content-Type": "text/html" }, |
112 | | - }) |
113 | | - } |
| 99 | + if (!code) { |
| 100 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 101 | + res.end(HTML_ERROR("No authorization code provided")) |
| 102 | + return |
| 103 | + } |
114 | 104 |
|
115 | | - // Validate state parameter |
116 | | - if (!pendingAuths.has(state)) { |
117 | | - const errorMsg = "Invalid or expired state parameter - potential CSRF attack" |
118 | | - log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) |
119 | | - return new Response(HTML_ERROR(errorMsg), { |
120 | | - status: 400, |
121 | | - headers: { "Content-Type": "text/html" }, |
122 | | - }) |
123 | | - } |
| 105 | + // Validate state parameter |
| 106 | + if (!pendingAuths.has(state)) { |
| 107 | + const errorMsg = "Invalid or expired state parameter - potential CSRF attack" |
| 108 | + log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) |
| 109 | + res.writeHead(400, { "Content-Type": "text/html" }) |
| 110 | + res.end(HTML_ERROR(errorMsg)) |
| 111 | + return |
| 112 | + } |
124 | 113 |
|
125 | | - const pending = pendingAuths.get(state)! |
| 114 | + const pending = pendingAuths.get(state)! |
126 | 115 |
|
127 | | - clearTimeout(pending.timeout) |
128 | | - pendingAuths.delete(state) |
129 | | - pending.resolve(code) |
| 116 | + clearTimeout(pending.timeout) |
| 117 | + pendingAuths.delete(state) |
| 118 | + pending.resolve(code) |
130 | 119 |
|
131 | | - return new Response(HTML_SUCCESS, { |
132 | | - headers: { "Content-Type": "text/html" }, |
133 | | - }) |
134 | | - }, |
135 | | - }) |
| 120 | + res.writeHead(200, { "Content-Type": "text/html" }) |
| 121 | + res.end(HTML_SUCCESS) |
| 122 | + } |
| 123 | + |
| 124 | + export async function ensureRunning(): Promise<void> { |
| 125 | + if (server) return |
136 | 126 |
|
137 | | - log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) |
| 127 | + const running = await isPortInUse() |
| 128 | + if (running) { |
| 129 | + log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) |
| 130 | + return |
| 131 | + } |
| 132 | + |
| 133 | + server = createServer(handleRequest) |
| 134 | + await new Promise<void>((resolve, reject) => { |
| 135 | + server!.listen(OAUTH_CALLBACK_PORT, () => { |
| 136 | + log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) |
| 137 | + resolve() |
| 138 | + }) |
| 139 | + server!.on("error", reject) |
| 140 | + }) |
138 | 141 | } |
139 | 142 |
|
140 | 143 | export function waitForCallback(oauthState: string): Promise<string> { |
@@ -174,7 +177,7 @@ export namespace McpOAuthCallback { |
174 | 177 |
|
175 | 178 | export async function stop(): Promise<void> { |
176 | 179 | if (server) { |
177 | | - server.stop() |
| 180 | + await new Promise<void>((resolve) => server!.close(() => resolve())) |
178 | 181 | server = undefined |
179 | 182 | log.info("oauth callback server stopped") |
180 | 183 | } |
|
0 commit comments