Skip to content

Commit 4313aa1

Browse files
committed
fix early cancel
1 parent b6ca548 commit 4313aa1

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

client/transport/streamable_http.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ func (c *StreamableHTTP) SendRequest(
220220
return nil, fmt.Errorf("failed to marshal request: %w", err)
221221
}
222222

223+
ctx, cancel := c.contextAwareOfClientClose(ctx)
224+
defer cancel()
225+
223226
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
224227
if err != nil {
225228
if errors.Is(err, errSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
@@ -295,18 +298,6 @@ func (c *StreamableHTTP) sendHTTP(
295298
body io.Reader,
296299
acceptType string,
297300
) (resp *http.Response, err error) {
298-
// Create a combined context that could be canceled when the client is closed
299-
newCtx, cancel := context.WithCancel(ctx)
300-
defer cancel()
301-
go func() {
302-
select {
303-
case <-c.closed:
304-
cancel()
305-
case <-newCtx.Done():
306-
// The original context was canceled, no need to do anything
307-
}
308-
}()
309-
ctx = newCtx
310301

311302
// Create HTTP request
312303
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
@@ -478,6 +469,9 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
478469
}
479470

480471
// Create HTTP request
472+
ctx, cancel := c.contextAwareOfClientClose(ctx)
473+
defer cancel()
474+
481475
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
482476
if err != nil {
483477
return fmt.Errorf("failed to send request: %w", err)
@@ -555,7 +549,9 @@ var (
555549

556550
func (c *StreamableHTTP) createGETConnectionToServer() error {
557551

558-
ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed
552+
ctx, cancel := c.contextAwareOfClientClose(context.Background())
553+
defer cancel()
554+
559555
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
560556
if err != nil {
561557
return fmt.Errorf("failed to send request: %w", err)
@@ -585,3 +581,17 @@ func (c *StreamableHTTP) createGETConnectionToServer() error {
585581

586582
return nil
587583
}
584+
585+
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
586+
newCtx, cancel := context.WithCancel(ctx)
587+
go func() {
588+
select {
589+
case <-c.closed:
590+
cancel()
591+
case <-newCtx.Done():
592+
// The original context was canceled
593+
cancel()
594+
}
595+
}()
596+
return newCtx, cancel
597+
}

0 commit comments

Comments
 (0)