From ce1000b1e8a36b4636f83582d264590a14080147 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Thu, 21 Aug 2025 18:50:18 +0000 Subject: [PATCH] feat: add OAuth authentication support for streamable-http MCP servers - Implement OAuthHandler to manage OAuth flows and token storage - Create OAuthStreamableHTTPClientTransport wrapper for OAuth-enabled connections - Add OAuth configuration support to MCP server schema - Store OAuth tokens securely in VS Code global storage - Support token refresh and re-authentication on 401 responses - Add comprehensive tests for OAuth functionality Fixes #7296 --- src/services/mcp/McpHub.ts | 91 +++- src/services/mcp/OAuthHandler.ts | 488 ++++++++++++++++++ .../mcp/OAuthStreamableHTTPClientTransport.ts | 314 +++++++++++ .../mcp/__tests__/OAuthHandler.spec.ts | 312 +++++++++++ 4 files changed, 1186 insertions(+), 19 deletions(-) create mode 100644 src/services/mcp/OAuthHandler.ts create mode 100644 src/services/mcp/OAuthStreamableHTTPClientTransport.ts create mode 100644 src/services/mcp/__tests__/OAuthHandler.spec.ts diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 271c6e1fb3..0c2618b13b 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -2,6 +2,8 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js" import { StdioClientTransport, getDefaultEnvironment } from "@modelcontextprotocol/sdk/client/stdio.js" import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js" import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js" +import { OAuthStreamableHTTPClientTransport } from "./OAuthStreamableHTTPClientTransport" +import { OAuthConfig } from "./OAuthHandler" import ReconnectingEventSource from "reconnecting-eventsource" import { CallToolResultSchema, @@ -117,6 +119,18 @@ const createServerTypeSchema = () => { type: z.enum(["streamable-http"]).optional(), url: z.string().url("URL must be a valid URL format"), headers: z.record(z.string()).optional(), + // OAuth configuration (optional) + oauth: z + .object({ + clientId: z.string(), + clientSecret: z.string().optional(), + authorizationUrl: z.string().url(), + tokenUrl: z.string().url(), + redirectUri: z.string().optional(), + scopes: z.array(z.string()).optional(), + additionalParams: z.record(z.string()).optional(), + }) + .optional(), // Ensure no stdio fields are present command: z.undefined().optional(), args: z.undefined().optional(), @@ -736,30 +750,69 @@ export class McpHub { console.error(`No stderr stream for ${name}`) } } else if (configInjected.type === "streamable-http") { - // Streamable HTTP connection - transport = new StreamableHTTPClientTransport(new URL(configInjected.url), { - requestInit: { + // Check if OAuth is configured + if (configInjected.oauth) { + // Use OAuth-enabled transport + const provider = this.providerRef.deref() + if (!provider) { + throw new Error("Provider not available for OAuth initialization") + } + + const oauthTransport = new OAuthStreamableHTTPClientTransport({ + url: new URL(configInjected.url), headers: configInjected.headers, - }, - }) + oauth: configInjected.oauth as OAuthConfig, + serverName: name, + context: provider.context, + }) - // Set up Streamable HTTP specific error handling - transport.onerror = async (error) => { - console.error(`Transport error for "${name}" (streamable-http):`, error) - const connection = this.findConnection(name, source) - if (connection) { - connection.server.status = "disconnected" - this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`) + // Get the underlying transport + transport = oauthTransport.getTransport() + + // Set up error handling + oauthTransport.onerror = async (error) => { + console.error(`Transport error for "${name}" (streamable-http with OAuth):`, error) + const connection = this.findConnection(name, source) + if (connection) { + connection.server.status = "disconnected" + this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`) + } + await this.notifyWebviewOfServerChanges() } - await this.notifyWebviewOfServerChanges() - } - transport.onclose = async () => { - const connection = this.findConnection(name, source) - if (connection) { - connection.server.status = "disconnected" + oauthTransport.onclose = async () => { + const connection = this.findConnection(name, source) + if (connection) { + connection.server.status = "disconnected" + } + await this.notifyWebviewOfServerChanges() + } + } else { + // Standard Streamable HTTP connection without OAuth + transport = new StreamableHTTPClientTransport(new URL(configInjected.url), { + requestInit: { + headers: configInjected.headers, + }, + }) + + // Set up Streamable HTTP specific error handling + transport.onerror = async (error) => { + console.error(`Transport error for "${name}" (streamable-http):`, error) + const connection = this.findConnection(name, source) + if (connection) { + connection.server.status = "disconnected" + this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`) + } + await this.notifyWebviewOfServerChanges() + } + + transport.onclose = async () => { + const connection = this.findConnection(name, source) + if (connection) { + connection.server.status = "disconnected" + } + await this.notifyWebviewOfServerChanges() } - await this.notifyWebviewOfServerChanges() } } else if (configInjected.type === "sse") { // SSE connection diff --git a/src/services/mcp/OAuthHandler.ts b/src/services/mcp/OAuthHandler.ts new file mode 100644 index 0000000000..b6daee0dee --- /dev/null +++ b/src/services/mcp/OAuthHandler.ts @@ -0,0 +1,488 @@ +import * as vscode from "vscode" +import * as http from "http" +import * as crypto from "crypto" +import * as fs from "fs/promises" +import * as path from "path" +import { URL } from "url" +import pkceChallenge from "pkce-challenge" + +export interface OAuthConfig { + clientId: string + clientSecret?: string + authorizationUrl: string + tokenUrl: string + redirectUri?: string + scopes?: string[] + additionalParams?: Record +} + +export interface OAuthTokens { + accessToken: string + refreshToken?: string + expiresAt?: number + tokenType?: string + scope?: string +} + +interface StoredOAuthData { + tokens: OAuthTokens + serverName: string + timestamp: number +} + +export class OAuthHandler { + private static instance: OAuthHandler | null = null + private server: http.Server | null = null + private pendingAuthorizations: Map void> = new Map() + private tokenStorage: Map = new Map() + private storageFilePath: string + + private constructor(private context: vscode.ExtensionContext) { + // Initialize storage file path + this.storageFilePath = path.join(context.globalStorageUri.fsPath, "mcp-oauth-tokens.json") + this.loadStoredTokens() + } + + public static getInstance(context: vscode.ExtensionContext): OAuthHandler { + if (!OAuthHandler.instance) { + OAuthHandler.instance = new OAuthHandler(context) + } + return OAuthHandler.instance + } + + /** + * Load stored OAuth tokens from disk + */ + private async loadStoredTokens(): Promise { + try { + const data = await fs.readFile(this.storageFilePath, "utf-8") + const parsed = JSON.parse(data) as Record + this.tokenStorage = new Map(Object.entries(parsed)) + } catch (error) { + // File doesn't exist or is invalid, start with empty storage + this.tokenStorage = new Map() + } + } + + /** + * Save OAuth tokens to disk + */ + private async saveStoredTokens(): Promise { + try { + const data = Object.fromEntries(this.tokenStorage) + await fs.mkdir(path.dirname(this.storageFilePath), { recursive: true }) + await fs.writeFile(this.storageFilePath, JSON.stringify(data, null, 2)) + } catch (error) { + console.error("Failed to save OAuth tokens:", error) + } + } + + /** + * Get stored tokens for a server + */ + public async getStoredTokens(serverName: string): Promise { + const stored = this.tokenStorage.get(serverName) + if (!stored) { + return null + } + + // Check if token is expired + if (stored.tokens.expiresAt && stored.tokens.expiresAt < Date.now()) { + // Token is expired, remove it + this.tokenStorage.delete(serverName) + await this.saveStoredTokens() + return null + } + + return stored.tokens + } + + /** + * Store tokens for a server + */ + private async storeTokens(serverName: string, tokens: OAuthTokens): Promise { + this.tokenStorage.set(serverName, { + tokens, + serverName, + timestamp: Date.now(), + }) + await this.saveStoredTokens() + } + + /** + * Clear stored tokens for a server + */ + public async clearTokens(serverName: string): Promise { + this.tokenStorage.delete(serverName) + await this.saveStoredTokens() + } + + /** + * Start the OAuth flow for a server + */ + public async authenticate(serverName: string, config: OAuthConfig): Promise { + // Check if we have valid stored tokens + const storedTokens = await this.getStoredTokens(serverName) + if (storedTokens) { + return storedTokens + } + + // Start OAuth flow + return new Promise((resolve) => { + this.startOAuthFlow(serverName, config, resolve) + }) + } + + private async startOAuthFlow( + serverName: string, + config: OAuthConfig, + resolve: (tokens: OAuthTokens | null) => void, + ): Promise { + try { + // Generate PKCE challenge + const pkce = await pkceChallenge() + const state = crypto.randomBytes(16).toString("hex") + + // Start local server if not already running + if (!this.server) { + await this.startCallbackServer() + } + + // Store the pending authorization + const authKey = `${serverName}-${state}` + this.pendingAuthorizations.set(authKey, async (tokens) => { + if (tokens) { + await this.storeTokens(serverName, tokens) + } + resolve(tokens) + }) + + // Build authorization URL + const authUrl = new URL(config.authorizationUrl) + authUrl.searchParams.set("client_id", config.clientId) + authUrl.searchParams.set("response_type", "code") + authUrl.searchParams.set("redirect_uri", config.redirectUri || "http://localhost:3000/callback") + authUrl.searchParams.set("state", state) + authUrl.searchParams.set("code_challenge", pkce.code_challenge) + authUrl.searchParams.set("code_challenge_method", "S256") + + if (config.scopes && config.scopes.length > 0) { + authUrl.searchParams.set("scope", config.scopes.join(" ")) + } + + // Add any additional parameters + if (config.additionalParams) { + for (const [key, value] of Object.entries(config.additionalParams)) { + authUrl.searchParams.set(key, value) + } + } + + // Store config for token exchange + this.storePendingConfig(authKey, config, pkce.code_verifier) + + // Open the authorization URL in the browser + const opened = await vscode.env.openExternal(vscode.Uri.parse(authUrl.toString())) + if (!opened) { + throw new Error("Failed to open authorization URL in browser") + } + + // Show information message + vscode.window.showInformationMessage( + `Opening browser for OAuth authentication for ${serverName}. Please complete the authorization flow.`, + ) + + // Set a timeout for the authorization + setTimeout( + () => { + if (this.pendingAuthorizations.has(authKey)) { + this.pendingAuthorizations.delete(authKey) + this.clearPendingConfig(authKey) + resolve(null) + vscode.window.showErrorMessage(`OAuth authentication timeout for ${serverName}`) + } + }, + 5 * 60 * 1000, + ) // 5 minutes timeout + } catch (error) { + console.error("OAuth authentication error:", error) + vscode.window.showErrorMessage(`OAuth authentication failed: ${error}`) + resolve(null) + } + } + + /** + * Start the local callback server + */ + private async startCallbackServer(): Promise { + return new Promise((resolve, reject) => { + this.server = http.createServer(async (req, res) => { + const url = new URL(req.url || "", `http://${req.headers.host}`) + + if (url.pathname === "/callback") { + await this.handleCallback(url, res) + } else { + res.writeHead(404) + res.end("Not found") + } + }) + + this.server.listen(3000, "localhost", () => { + console.log("OAuth callback server listening on http://localhost:3000") + resolve() + }) + + this.server.on("error", (error) => { + console.error("OAuth callback server error:", error) + reject(error) + }) + }) + } + + /** + * Handle OAuth callback + */ + private async handleCallback(url: URL, res: http.ServerResponse): Promise { + const code = url.searchParams.get("code") + const state = url.searchParams.get("state") + const error = url.searchParams.get("error") + + if (error) { + res.writeHead(200, { "Content-Type": "text/html" }) + res.end(` + + +

Authentication Failed

+

Error: ${error}

+

You can close this window.

+ + + `) + return + } + + if (!code || !state) { + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(` + + +

Invalid Request

+

Missing authorization code or state.

+ + + `) + return + } + + // Find the pending authorization + let authKey: string | null = null + for (const key of this.pendingAuthorizations.keys()) { + if (key.endsWith(`-${state}`)) { + authKey = key + break + } + } + + if (!authKey) { + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(` + + +

Invalid State

+

The authorization state is invalid or expired.

+ + + `) + return + } + + const callback = this.pendingAuthorizations.get(authKey) + const config = this.getPendingConfig(authKey) + + if (!callback || !config) { + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(` + + +

Configuration Error

+

Missing configuration for this authorization.

+ + + `) + return + } + + try { + // Exchange code for tokens + const tokens = await this.exchangeCodeForTokens(code, config) + + // Send success response + res.writeHead(200, { "Content-Type": "text/html" }) + res.end(` + + +

Authentication Successful!

+

You can close this window and return to VS Code.

+ + + + `) + + // Clean up and call callback + this.pendingAuthorizations.delete(authKey) + this.clearPendingConfig(authKey) + callback(tokens) + } catch (error) { + console.error("Token exchange error:", error) + res.writeHead(500, { "Content-Type": "text/html" }) + res.end(` + + +

Token Exchange Failed

+

Error: ${error}

+ + + `) + + // Clean up and call callback with null + this.pendingAuthorizations.delete(authKey) + this.clearPendingConfig(authKey) + callback(null) + } + } + + /** + * Exchange authorization code for tokens + */ + private async exchangeCodeForTokens( + code: string, + config: { oauth: OAuthConfig; codeVerifier: string }, + ): Promise { + const params = new URLSearchParams({ + grant_type: "authorization_code", + code, + redirect_uri: config.oauth.redirectUri || "http://localhost:3000/callback", + client_id: config.oauth.clientId, + code_verifier: config.codeVerifier, + }) + + if (config.oauth.clientSecret) { + params.set("client_secret", config.oauth.clientSecret) + } + + const response = await fetch(config.oauth.tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params.toString(), + }) + + if (!response.ok) { + const error = await response.text() + throw new Error(`Token exchange failed: ${response.status} - ${error}`) + } + + const data = await response.json() + + // Calculate expiration time if expires_in is provided + let expiresAt: number | undefined + if (data.expires_in) { + expiresAt = Date.now() + data.expires_in * 1000 + } + + return { + accessToken: data.access_token, + refreshToken: data.refresh_token, + expiresAt, + tokenType: data.token_type, + scope: data.scope, + } + } + + /** + * Refresh an access token + */ + public async refreshToken( + serverName: string, + config: OAuthConfig, + refreshToken: string, + ): Promise { + try { + const params = new URLSearchParams({ + grant_type: "refresh_token", + refresh_token: refreshToken, + client_id: config.clientId, + }) + + if (config.clientSecret) { + params.set("client_secret", config.clientSecret) + } + + const response = await fetch(config.tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params.toString(), + }) + + if (!response.ok) { + // Refresh failed, clear stored tokens + await this.clearTokens(serverName) + return null + } + + const data = await response.json() + + // Calculate expiration time if expires_in is provided + let expiresAt: number | undefined + if (data.expires_in) { + expiresAt = Date.now() + data.expires_in * 1000 + } + + const tokens: OAuthTokens = { + accessToken: data.access_token, + refreshToken: data.refresh_token || refreshToken, // Use new refresh token if provided, otherwise keep the old one + expiresAt, + tokenType: data.token_type, + scope: data.scope, + } + + // Store the new tokens + await this.storeTokens(serverName, tokens) + + return tokens + } catch (error) { + console.error("Token refresh error:", error) + await this.clearTokens(serverName) + return null + } + } + + // Temporary storage for pending OAuth configs + private pendingConfigs: Map = new Map() + + private storePendingConfig(authKey: string, config: OAuthConfig, codeVerifier: string): void { + this.pendingConfigs.set(authKey, { oauth: config, codeVerifier }) + } + + private getPendingConfig(authKey: string): { oauth: OAuthConfig; codeVerifier: string } | undefined { + return this.pendingConfigs.get(authKey) + } + + private clearPendingConfig(authKey: string): void { + this.pendingConfigs.delete(authKey) + } + + /** + * Dispose of the OAuth handler + */ + public dispose(): void { + if (this.server) { + this.server.close() + this.server = null + } + this.pendingAuthorizations.clear() + this.pendingConfigs.clear() + } +} diff --git a/src/services/mcp/OAuthStreamableHTTPClientTransport.ts b/src/services/mcp/OAuthStreamableHTTPClientTransport.ts new file mode 100644 index 0000000000..907e4ef121 --- /dev/null +++ b/src/services/mcp/OAuthStreamableHTTPClientTransport.ts @@ -0,0 +1,314 @@ +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js" +import { OAuthHandler, OAuthConfig, OAuthTokens } from "./OAuthHandler" +import * as vscode from "vscode" + +export interface OAuthStreamableHTTPConfig { + url: URL + headers?: Record + oauth?: OAuthConfig + serverName: string + context: vscode.ExtensionContext +} + +/** + * Wrapper for StreamableHTTPClientTransport that adds OAuth authentication support + */ +export class OAuthStreamableHTTPClientTransport { + private transport: StreamableHTTPClientTransport + private oauthHandler: OAuthHandler | null = null + private oauthConfig: OAuthConfig | null = null + private serverName: string + private tokens: OAuthTokens | null = null + private isAuthenticating: boolean = false + private authenticationPromise: Promise | null = null + private originalHeaders: Record + + constructor(config: OAuthStreamableHTTPConfig) { + this.serverName = config.serverName + this.originalHeaders = config.headers || {} + + // If OAuth config is provided, set up OAuth handler + if (config.oauth) { + this.oauthConfig = config.oauth + this.oauthHandler = OAuthHandler.getInstance(config.context) + } + + // Create the base transport with initial headers + this.transport = new StreamableHTTPClientTransport(config.url, { + requestInit: { + headers: this.originalHeaders, + }, + }) + + // Intercept the transport to add OAuth headers + if (this.oauthConfig) { + this.setupOAuthInterception() + } + } + + /** + * Set up OAuth interception for the transport + */ + private setupOAuthInterception(): void { + // Store the original fetch function + const originalFetch = globalThis.fetch + + // Create a custom fetch that adds OAuth headers + const oauthFetch = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + // Only intercept requests to our MCP server URL + const url = typeof input === "string" ? input : input instanceof URL ? input.href : input.url + + // Check if this is a request to our MCP server + if (url && this.shouldInterceptRequest(url)) { + // Ensure we have valid tokens + await this.ensureAuthenticated() + + // Add OAuth token to headers + if (this.tokens) { + const headers = new Headers(init?.headers || {}) + headers.set("Authorization", `${this.tokens.tokenType || "Bearer"} ${this.tokens.accessToken}`) + + init = { + ...init, + headers, + } + } + + // Make the request + const response = await originalFetch(input, init) + + // Check if we got a 401 Unauthorized response + if (response.status === 401) { + // Token might be expired, try to refresh or re-authenticate + await this.handleUnauthorized() + + // Retry the request with new token + if (this.tokens) { + const headers = new Headers(init?.headers || {}) + headers.set("Authorization", `${this.tokens.tokenType || "Bearer"} ${this.tokens.accessToken}`) + + const retryInit = { + ...init, + headers, + } + + return originalFetch(input, retryInit) + } + } + + return response + } + + // Not our request, pass through + return originalFetch(input, init) + } + + // Replace global fetch temporarily when our transport is active + // This is a workaround since StreamableHTTPClientTransport doesn't expose its fetch method + const originalStart = this.transport.start.bind(this.transport) + const originalClose = this.transport.close.bind(this.transport) + + this.transport.start = async () => { + // Replace fetch + globalThis.fetch = oauthFetch as typeof fetch + + // Ensure we're authenticated before starting + if (this.oauthConfig && this.oauthHandler) { + try { + await this.ensureAuthenticated() + } catch (error) { + console.error(`OAuth authentication failed for ${this.serverName}:`, error) + // Continue anyway - the server might not require auth for all endpoints + } + } + + return originalStart() + } + + this.transport.close = async () => { + // Restore original fetch + globalThis.fetch = originalFetch + return originalClose() + } + } + + /** + * Check if we should intercept this request + */ + private shouldInterceptRequest(url: string): boolean { + // This is a simplified check - you might want to make this more sophisticated + // For now, we'll intercept all requests while this transport is active + return true + } + + /** + * Ensure we have valid OAuth tokens + */ + private async ensureAuthenticated(): Promise { + // If already authenticating, wait for it to complete + if (this.isAuthenticating && this.authenticationPromise) { + await this.authenticationPromise + return + } + + // If we already have tokens, check if they're still valid + if (this.tokens) { + if (!this.tokens.expiresAt || this.tokens.expiresAt > Date.now()) { + // Tokens are still valid + return + } + + // Try to refresh the token + if (this.tokens.refreshToken && this.oauthHandler && this.oauthConfig) { + const refreshedTokens = await this.oauthHandler.refreshToken( + this.serverName, + this.oauthConfig, + this.tokens.refreshToken, + ) + + if (refreshedTokens) { + this.tokens = refreshedTokens + return + } + } + } + + // Check for stored tokens + if (this.oauthHandler) { + const storedTokens = await this.oauthHandler.getStoredTokens(this.serverName) + if (storedTokens) { + this.tokens = storedTokens + return + } + } + + // Need to authenticate + await this.authenticate() + } + + /** + * Handle 401 Unauthorized response + */ + private async handleUnauthorized(): Promise { + // If we have a refresh token, try to refresh + if (this.tokens?.refreshToken && this.oauthHandler && this.oauthConfig) { + const refreshedTokens = await this.oauthHandler.refreshToken( + this.serverName, + this.oauthConfig, + this.tokens.refreshToken, + ) + + if (refreshedTokens) { + this.tokens = refreshedTokens + return + } + } + + // Clear stored tokens and re-authenticate + if (this.oauthHandler) { + await this.oauthHandler.clearTokens(this.serverName) + } + + // Clear current tokens + this.tokens = null + + await this.authenticate() + } + + /** + * Perform OAuth authentication + */ + private async authenticate(): Promise { + if (!this.oauthHandler || !this.oauthConfig) { + throw new Error("OAuth not configured for this transport") + } + + // Prevent multiple simultaneous authentication attempts + if (this.isAuthenticating) { + if (this.authenticationPromise) { + await this.authenticationPromise + } + return + } + + this.isAuthenticating = true + this.authenticationPromise = (async () => { + try { + const tokens = await this.oauthHandler!.authenticate(this.serverName, this.oauthConfig!) + + if (!tokens) { + throw new Error("OAuth authentication failed or was cancelled") + } + + this.tokens = tokens + } finally { + this.isAuthenticating = false + this.authenticationPromise = null + } + })() + + await this.authenticationPromise + } + + /** + * Get the underlying transport + */ + public getTransport(): StreamableHTTPClientTransport { + return this.transport + } + + /** + * Start the transport + */ + public async start(): Promise { + await this.transport.start() + } + + /** + * Close the transport + */ + public async close(): Promise { + await this.transport.close() + } + + /** + * Check if OAuth is configured for this transport + */ + public hasOAuth(): boolean { + return this.oauthConfig !== null + } + + /** + * Get the current OAuth tokens (if any) + */ + public getTokens(): OAuthTokens | null { + return this.tokens + } + + /** + * Clear OAuth tokens and force re-authentication on next request + */ + public async clearTokens(): Promise { + this.tokens = null + if (this.oauthHandler) { + await this.oauthHandler.clearTokens(this.serverName) + } + } + + // Proxy all other properties and methods to the underlying transport + get onerror() { + return this.transport.onerror + } + + set onerror(handler: ((error: Error) => void) | undefined) { + this.transport.onerror = handler + } + + get onclose() { + return this.transport.onclose + } + + set onclose(handler: (() => void) | undefined) { + this.transport.onclose = handler + } +} diff --git a/src/services/mcp/__tests__/OAuthHandler.spec.ts b/src/services/mcp/__tests__/OAuthHandler.spec.ts new file mode 100644 index 0000000000..54994c7e7e --- /dev/null +++ b/src/services/mcp/__tests__/OAuthHandler.spec.ts @@ -0,0 +1,312 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import * as vscode from "vscode" +import { OAuthHandler, OAuthConfig } from "../OAuthHandler" +import * as fs from "fs/promises" +import * as http from "http" + +// Mock vscode +vi.mock("vscode", () => ({ + window: { + showInformationMessage: vi.fn(), + showErrorMessage: vi.fn(), + }, + env: { + openExternal: vi.fn().mockResolvedValue(true), + }, + Uri: { + parse: vi.fn((str: string) => ({ toString: () => str })), + }, + ExtensionContext: vi.fn(), +})) + +// Mock fs +vi.mock("fs/promises", () => ({ + readFile: vi.fn(), + writeFile: vi.fn(), + mkdir: vi.fn(), + access: vi.fn(), +})) + +// Mock http +vi.mock("http", () => { + const mockServer = { + listen: vi.fn((port: number, host: string, callback: () => void) => { + callback() + }), + close: vi.fn(), + on: vi.fn(), + } + return { + createServer: vi.fn(() => mockServer), + Server: vi.fn(), + } +}) + +// Mock pkce-challenge +vi.mock("pkce-challenge", () => ({ + default: vi.fn().mockResolvedValue({ + code_challenge: "test_challenge", + code_verifier: "test_verifier", + }), +})) + +describe("OAuthHandler", () => { + let mockContext: vscode.ExtensionContext + let oauthHandler: OAuthHandler + + beforeEach(() => { + // Reset all mocks + vi.clearAllMocks() + + // Create mock context + mockContext = { + globalStorageUri: { + fsPath: "/test/storage", + }, + } as any + + // Reset singleton + ;(OAuthHandler as any).instance = null + }) + + afterEach(() => { + // Clean up + if (oauthHandler) { + oauthHandler.dispose() + } + }) + + describe("getInstance", () => { + it("should create a singleton instance", () => { + const instance1 = OAuthHandler.getInstance(mockContext) + const instance2 = OAuthHandler.getInstance(mockContext) + + expect(instance1).toBe(instance2) + }) + }) + + describe("getStoredTokens", () => { + it("should return null when no tokens are stored", async () => { + vi.mocked(fs.readFile).mockRejectedValue(new Error("File not found")) + + oauthHandler = OAuthHandler.getInstance(mockContext) + const tokens = await oauthHandler.getStoredTokens("test-server") + + expect(tokens).toBeNull() + }) + + it("should return stored tokens when they exist and are valid", async () => { + const storedData = { + "test-server": { + tokens: { + accessToken: "test_token", + refreshToken: "refresh_token", + expiresAt: Date.now() + 3600000, // 1 hour from now + }, + serverName: "test-server", + timestamp: Date.now(), + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(storedData)) + + oauthHandler = OAuthHandler.getInstance(mockContext) + const tokens = await oauthHandler.getStoredTokens("test-server") + + expect(tokens).toEqual({ + accessToken: "test_token", + refreshToken: "refresh_token", + expiresAt: expect.any(Number), + }) + }) + + it("should return null when tokens are expired", async () => { + const storedData = { + "test-server": { + tokens: { + accessToken: "test_token", + refreshToken: "refresh_token", + expiresAt: Date.now() - 3600000, // 1 hour ago + }, + serverName: "test-server", + timestamp: Date.now(), + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(storedData)) + + oauthHandler = OAuthHandler.getInstance(mockContext) + const tokens = await oauthHandler.getStoredTokens("test-server") + + expect(tokens).toBeNull() + }) + }) + + describe("clearTokens", () => { + it("should remove tokens for a server", async () => { + const storedData = { + "test-server": { + tokens: { + accessToken: "test_token", + }, + serverName: "test-server", + timestamp: Date.now(), + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(storedData)) + vi.mocked(fs.writeFile).mockResolvedValue() + + oauthHandler = OAuthHandler.getInstance(mockContext) + await oauthHandler.clearTokens("test-server") + + // Check that writeFile was called with empty object for that server + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringContaining("mcp-oauth-tokens.json"), + expect.stringContaining("{}"), + ) + }) + }) + + describe("authenticate", () => { + it("should return stored tokens if they exist", async () => { + const storedData = { + "test-server": { + tokens: { + accessToken: "stored_token", + refreshToken: "stored_refresh", + expiresAt: Date.now() + 3600000, + }, + serverName: "test-server", + timestamp: Date.now(), + }, + } + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(storedData)) + + oauthHandler = OAuthHandler.getInstance(mockContext) + + const config: OAuthConfig = { + clientId: "test_client", + authorizationUrl: "https://auth.example.com/authorize", + tokenUrl: "https://auth.example.com/token", + } + + const tokens = await oauthHandler.authenticate("test-server", config) + + expect(tokens).toEqual({ + accessToken: "stored_token", + refreshToken: "stored_refresh", + expiresAt: expect.any(Number), + }) + + // Should not open browser if tokens exist + expect(vscode.env.openExternal).not.toHaveBeenCalled() + }) + + it("should start OAuth flow when no tokens exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue(new Error("File not found")) + + oauthHandler = OAuthHandler.getInstance(mockContext) + + const config: OAuthConfig = { + clientId: "test_client", + authorizationUrl: "https://auth.example.com/authorize", + tokenUrl: "https://auth.example.com/token", + scopes: ["read", "write"], + } + + // Start authentication (won't complete without callback) + const authPromise = oauthHandler.authenticate("test-server", config) + + // Should open browser + expect(vscode.env.openExternal).toHaveBeenCalledWith( + expect.objectContaining({ + toString: expect.any(Function), + }), + ) + + // Check the URL that was opened + const urlArg = vi.mocked(vscode.env.openExternal).mock.calls[0][0] + const url = urlArg.toString() + expect(url).toContain("https://auth.example.com/authorize") + expect(url).toContain("client_id=test_client") + expect(url).toContain("response_type=code") + expect(url).toContain("scope=read%20write") + expect(url).toContain("code_challenge=test_challenge") + }) + }) + + describe("refreshToken", () => { + it("should refresh tokens successfully", async () => { + // Mock successful token refresh response + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: "new_access_token", + refresh_token: "new_refresh_token", + expires_in: 3600, + token_type: "Bearer", + }), + }) + + vi.mocked(fs.readFile).mockRejectedValue(new Error("File not found")) + vi.mocked(fs.writeFile).mockResolvedValue() + vi.mocked(fs.mkdir).mockResolvedValue(undefined as any) + + oauthHandler = OAuthHandler.getInstance(mockContext) + + const config: OAuthConfig = { + clientId: "test_client", + clientSecret: "test_secret", + authorizationUrl: "https://auth.example.com/authorize", + tokenUrl: "https://auth.example.com/token", + } + + const tokens = await oauthHandler.refreshToken("test-server", config, "old_refresh_token") + + expect(tokens).toEqual({ + accessToken: "new_access_token", + refreshToken: "new_refresh_token", + expiresAt: expect.any(Number), + tokenType: "Bearer", + scope: undefined, + }) + + // Check that fetch was called with correct parameters + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/token", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: expect.stringContaining("grant_type=refresh_token"), + }), + ) + }) + + it("should return null when refresh fails", async () => { + // Mock failed token refresh response + global.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 401, + text: async () => "Invalid refresh token", + }) + + vi.mocked(fs.readFile).mockRejectedValue(new Error("File not found")) + + oauthHandler = OAuthHandler.getInstance(mockContext) + + const config: OAuthConfig = { + clientId: "test_client", + authorizationUrl: "https://auth.example.com/authorize", + tokenUrl: "https://auth.example.com/token", + } + + const tokens = await oauthHandler.refreshToken("test-server", config, "invalid_refresh_token") + + expect(tokens).toBeNull() + }) + }) +})