Skip to content

Commit 4086efe

Browse files
feat: add support for custom HTTP headers in client requests (#546)
* feat: add support for custom HTTP headers in client requests This update introduces the ability to include custom HTTP headers in requests sent from the client. This enhancement facilitates more flexible and secure communication with servers by allowing clients to pass additional information in the header of each request, such as authentication tokens or custom metadata. This feature is crucial for integrating with APIs that require specific headers for access control, content negotiation, or tracking purposes. Signed-off-by: Matthis Holleville <[email protected]> * feat(client/transport): enhance HTTP request flexibility Enhanced the flexibility of HTTP requests in the streamable HTTP client by allowing additional headers to be specified. This change aims to support more diverse server requirements and improve the adaptability of our client transport layer. Signed-off-by: Matthis Holleville <[email protected]> * fix: Improve OAuth error handling and test readability in HTTP transport - Enhanced OAuth error detection by using `errors.Is` for more reliable error handling. - Corrected a typo in a comment and improved code readability in tests by using a variable for headers. Signed-off-by: Matthis Holleville <[email protected]> * fix: improve variable naming for clarity in streamable_http_test Signed-off-by: Matthis Holleville <[email protected]> * feat: Ensure system headers are preserved in streamable HTTP tests To maintain consistency and ensure the integrity of HTTP headers during tests, system headers like Content-Type are now verified to be preserved. This change enhances the reliability of our testing framework by ensuring essential headers are not inadvertently removed or altered during the testing process. Signed-off-by: Matthis Holleville <[email protected]> --------- Signed-off-by: Matthis Holleville <[email protected]>
1 parent f60537b commit 4086efe

File tree

4 files changed

+85
-15
lines changed

4 files changed

+85
-15
lines changed

client/client.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"net/http"
78
"slices"
89
"sync"
910
"sync/atomic"
@@ -140,6 +141,7 @@ func (c *Client) sendRequest(
140141
ctx context.Context,
141142
method string,
142143
params any,
144+
header http.Header,
143145
) (*json.RawMessage, error) {
144146
if !c.initialized && method != "initialize" {
145147
return nil, fmt.Errorf("client not initialized")
@@ -152,6 +154,7 @@ func (c *Client) sendRequest(
152154
ID: mcp.NewRequestId(id),
153155
Method: method,
154156
Params: params,
157+
Header: header,
155158
}
156159

157160
response, err := c.transport.SendRequest(ctx, request)
@@ -193,7 +196,7 @@ func (c *Client) Initialize(
193196
Capabilities: capabilities,
194197
}
195198

196-
response, err := c.sendRequest(ctx, "initialize", params)
199+
response, err := c.sendRequest(ctx, "initialize", params, request.Header)
197200
if err != nil {
198201
return nil, err
199202
}
@@ -238,7 +241,7 @@ func (c *Client) Initialize(
238241
}
239242

240243
func (c *Client) Ping(ctx context.Context) error {
241-
_, err := c.sendRequest(ctx, "ping", nil)
244+
_, err := c.sendRequest(ctx, "ping", nil, nil)
242245
return err
243246
}
244247

@@ -319,7 +322,7 @@ func (c *Client) ReadResource(
319322
ctx context.Context,
320323
request mcp.ReadResourceRequest,
321324
) (*mcp.ReadResourceResult, error) {
322-
response, err := c.sendRequest(ctx, "resources/read", request.Params)
325+
response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header)
323326
if err != nil {
324327
return nil, err
325328
}
@@ -331,15 +334,15 @@ func (c *Client) Subscribe(
331334
ctx context.Context,
332335
request mcp.SubscribeRequest,
333336
) error {
334-
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
337+
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header)
335338
return err
336339
}
337340

338341
func (c *Client) Unsubscribe(
339342
ctx context.Context,
340343
request mcp.UnsubscribeRequest,
341344
) error {
342-
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
345+
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header)
343346
return err
344347
}
345348

@@ -383,7 +386,7 @@ func (c *Client) GetPrompt(
383386
ctx context.Context,
384387
request mcp.GetPromptRequest,
385388
) (*mcp.GetPromptResult, error) {
386-
response, err := c.sendRequest(ctx, "prompts/get", request.Params)
389+
response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header)
387390
if err != nil {
388391
return nil, err
389392
}
@@ -431,7 +434,7 @@ func (c *Client) CallTool(
431434
ctx context.Context,
432435
request mcp.CallToolRequest,
433436
) (*mcp.CallToolResult, error) {
434-
response, err := c.sendRequest(ctx, "tools/call", request.Params)
437+
response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header)
435438
if err != nil {
436439
return nil, err
437440
}
@@ -443,15 +446,15 @@ func (c *Client) SetLevel(
443446
ctx context.Context,
444447
request mcp.SetLevelRequest,
445448
) error {
446-
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
449+
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header)
447450
return err
448451
}
449452

450453
func (c *Client) Complete(
451454
ctx context.Context,
452455
request mcp.CompleteRequest,
453456
) (*mcp.CompleteResult, error) {
454-
response, err := c.sendRequest(ctx, "completion/complete", request.Params)
457+
response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header)
455458
if err != nil {
456459
return nil, err
457460
}
@@ -591,7 +594,7 @@ func listByPage[T any](
591594
request mcp.PaginatedRequest,
592595
method string,
593596
) (*T, error) {
594-
response, err := client.sendRequest(ctx, method, request.Params)
597+
response, err := client.sendRequest(ctx, method, request.Params, nil)
595598
if err != nil {
596599
return nil, err
597600
}

client/transport/interface.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package transport
33
import (
44
"context"
55
"encoding/json"
6+
"net/http"
67

78
"github.com/mark3labs/mcp-go/mcp"
89
)
@@ -59,6 +60,7 @@ type JSONRPCRequest struct {
5960
ID mcp.RequestId `json:"id"`
6061
Method string `json:"method"`
6162
Params any `json:"params,omitempty"`
63+
Header http.Header `json:"-"`
6264
}
6365

6466
// JSONRPCResponse represents a JSON-RPC 2.0 response message.

client/transport/streamable_http.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ func (c *StreamableHTTP) SendRequest(
265265
ctx, cancel := c.contextAwareOfClientClose(ctx)
266266
defer cancel()
267267

268-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
268+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", request.Header)
269269
if err != nil {
270270
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
271271
// If the request is initialize, should not return a SessionTerminated error
@@ -346,13 +346,19 @@ func (c *StreamableHTTP) sendHTTP(
346346
method string,
347347
body io.Reader,
348348
acceptType string,
349+
header http.Header,
349350
) (resp *http.Response, err error) {
350351
// Create HTTP request
351352
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
352353
if err != nil {
353354
return nil, fmt.Errorf("failed to create request: %w", err)
354355
}
355356

357+
// request headers
358+
if header != nil {
359+
req.Header = header
360+
}
361+
356362
// Set headers
357363
req.Header.Set("Content-Type", "application/json")
358364
req.Header.Set("Accept", acceptType)
@@ -375,7 +381,7 @@ func (c *StreamableHTTP) sendHTTP(
375381
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
376382
if err != nil {
377383
// If we get an authorization error, return a specific error that can be handled by the client
378-
if err.Error() == "no valid token available, authorization required" {
384+
if errors.Is(err, ErrOAuthAuthorizationRequired) {
379385
return nil, &OAuthAuthorizationRequiredError{
380386
Handler: c.oauthHandler,
381387
}
@@ -546,7 +552,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
546552
ctx, cancel := c.contextAwareOfClientClose(ctx)
547553
defer cancel()
548554

549-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
555+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", nil)
550556
if err != nil {
551557
return fmt.Errorf("failed to send request: %w", err)
552558
}
@@ -642,7 +648,7 @@ var (
642648
)
643649

644650
func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
645-
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
651+
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream", nil)
646652
if err != nil {
647653
return fmt.Errorf("failed to send request: %w", err)
648654
}
@@ -757,7 +763,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO
757763
ctx, cancel := c.contextAwareOfClientClose(ctx)
758764
defer cancel()
759765

760-
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream")
766+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream", nil)
761767
if err != nil {
762768
c.logger.Errorf("failed to send response to server: %v", err)
763769
return

client/transport/streamable_http_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ func startMockStreamableHTTPServer() (string, func()) {
7070
"jsonrpc": "2.0",
7171
"id": request["id"],
7272
"result": request,
73+
"headers": r.Header,
7374
}); err != nil {
7475
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
7576
return
@@ -122,6 +123,24 @@ func startMockStreamableHTTPServer() (string, func()) {
122123
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
123124
return
124125
}
126+
case "debug/echo_header":
127+
// Check session ID
128+
if r.Header.Get("Mcp-Session-Id") != sessionID {
129+
http.Error(w, "Invalid session ID", http.StatusNotFound)
130+
return
131+
}
132+
133+
// Echo back the request headers as the response result
134+
w.Header().Set("Content-Type", "application/json")
135+
w.WriteHeader(http.StatusOK)
136+
if err := json.NewEncoder(w).Encode(map[string]any{
137+
"jsonrpc": "2.0",
138+
"id": request["id"],
139+
"result": r.Header,
140+
}); err != nil {
141+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
142+
return
143+
}
125144
}
126145
})
127146

@@ -215,6 +234,46 @@ func TestStreamableHTTP(t *testing.T) {
215234
}
216235
})
217236

237+
t.Run("SendRequestWithHeader", func(t *testing.T) {
238+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
239+
defer cancel()
240+
241+
params := map[string]any{
242+
"string": "hello world",
243+
"array": []any{1, 2, 3},
244+
}
245+
246+
hdr := http.Header{"X-Test-Header": {"test-header-value"}}
247+
request := JSONRPCRequest{
248+
JSONRPC: "2.0",
249+
ID: mcp.NewRequestId(int64(1)),
250+
Method: "debug/echo_header",
251+
Params: params,
252+
Header: hdr,
253+
}
254+
255+
// Send the request
256+
response, err := trans.SendRequest(ctx, request)
257+
if err != nil {
258+
t.Fatalf("SendRequest failed: %v", err)
259+
}
260+
261+
// Parse the result to verify echo
262+
var result map[string]any
263+
if err := json.Unmarshal(response.Result, &result); err != nil {
264+
t.Fatalf("Failed to unmarshal result: %v", err)
265+
}
266+
267+
if headerValues, ok := result["X-Test-Header"].([]any); !ok || len(headerValues) == 0 || headerValues[0] != "test-header-value" {
268+
t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"])
269+
}
270+
271+
// Verify system headers are still present
272+
if contentType, ok := result["Content-Type"].([]any); !ok || len(contentType) == 0 {
273+
t.Errorf("Expected Content-Type header to be preserved")
274+
}
275+
})
276+
218277
t.Run("SendRequestWithTimeout", func(t *testing.T) {
219278
// Create a context that's already canceled
220279
ctx, cancel := context.WithCancel(context.Background())

0 commit comments

Comments
 (0)