Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions src/services/mcp/McpHub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"
import { StdioClientTransport, getDefaultEnvironment } from "@modelcontextprotocol/sdk/client/stdio.js"
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
import { OAuthStreamableHTTPClientTransport } from "./OAuthStreamableHTTPClientTransport"
import { OAuthConfig } from "./OAuthHandler"
import ReconnectingEventSource from "reconnecting-eventsource"
import {
CallToolResultSchema,
Expand Down Expand Up @@ -117,6 +119,18 @@ const createServerTypeSchema = () => {
type: z.enum(["streamable-http"]).optional(),
url: z.string().url("URL must be a valid URL format"),
headers: z.record(z.string()).optional(),
// OAuth configuration (optional)
oauth: z
.object({
clientId: z.string(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OAuth configuration is added to the schema but there's no validation of the OAuth URLs or required fields before use. Could we add validation to ensure the URLs are valid and all required OAuth fields are present when oauth is configured?

clientSecret: z.string().optional(),
authorizationUrl: z.string().url(),
tokenUrl: z.string().url(),
redirectUri: z.string().optional(),
scopes: z.array(z.string()).optional(),
additionalParams: z.record(z.string()).optional(),
})
.optional(),
// Ensure no stdio fields are present
command: z.undefined().optional(),
args: z.undefined().optional(),
Expand Down Expand Up @@ -736,30 +750,69 @@ export class McpHub {
console.error(`No stderr stream for ${name}`)
}
} else if (configInjected.type === "streamable-http") {
// Streamable HTTP connection
transport = new StreamableHTTPClientTransport(new URL(configInjected.url), {
requestInit: {
// Check if OAuth is configured
if (configInjected.oauth) {
// Use OAuth-enabled transport
const provider = this.providerRef.deref()
if (!provider) {
throw new Error("Provider not available for OAuth initialization")
}

const oauthTransport = new OAuthStreamableHTTPClientTransport({
url: new URL(configInjected.url),
headers: configInjected.headers,
},
})
oauth: configInjected.oauth as OAuthConfig,
serverName: name,
context: provider.context,
})

// Set up Streamable HTTP specific error handling
transport.onerror = async (error) => {
console.error(`Transport error for "${name}" (streamable-http):`, error)
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`)
// Get the underlying transport
transport = oauthTransport.getTransport()

// Set up error handling
oauthTransport.onerror = async (error) => {
console.error(`Transport error for "${name}" (streamable-http with OAuth):`, error)
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`)
}
await this.notifyWebviewOfServerChanges()
}
await this.notifyWebviewOfServerChanges()
}

transport.onclose = async () => {
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
oauthTransport.onclose = async () => {
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
}
await this.notifyWebviewOfServerChanges()
}
} else {
// Standard Streamable HTTP connection without OAuth
transport = new StreamableHTTPClientTransport(new URL(configInjected.url), {
requestInit: {
headers: configInjected.headers,
},
})

// Set up Streamable HTTP specific error handling
transport.onerror = async (error) => {
console.error(`Transport error for "${name}" (streamable-http):`, error)
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`)
}
await this.notifyWebviewOfServerChanges()
}

transport.onclose = async () => {
const connection = this.findConnection(name, source)
if (connection) {
connection.server.status = "disconnected"
}
await this.notifyWebviewOfServerChanges()
}
await this.notifyWebviewOfServerChanges()
}
} else if (configInjected.type === "sse") {
// SSE connection
Expand Down
Loading
Loading