@@ -220,6 +220,9 @@ func (c *StreamableHTTP) SendRequest(
220220 return nil , fmt .Errorf ("failed to marshal request: %w" , err )
221221 }
222222
223+ ctx , cancel := c .contextAwareOfClientClose (ctx )
224+ defer cancel ()
225+
223226 resp , err := c .sendHTTP (ctx , http .MethodPost , bytes .NewReader (requestBody ), "application/json, text/event-stream" )
224227 if err != nil {
225228 if errors .Is (err , errSessionTerminated ) && request .Method == string (mcp .MethodInitialize ) {
@@ -295,18 +298,6 @@ func (c *StreamableHTTP) sendHTTP(
295298 body io.Reader ,
296299 acceptType string ,
297300) (resp * http.Response , err error ) {
298- // Create a combined context that could be canceled when the client is closed
299- newCtx , cancel := context .WithCancel (ctx )
300- defer cancel ()
301- go func () {
302- select {
303- case <- c .closed :
304- cancel ()
305- case <- newCtx .Done ():
306- // The original context was canceled, no need to do anything
307- }
308- }()
309- ctx = newCtx
310301
311302 // Create HTTP request
312303 req , err := http .NewRequestWithContext (ctx , method , c .serverURL .String (), body )
@@ -478,6 +469,9 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
478469 }
479470
480471 // Create HTTP request
472+ ctx , cancel := c .contextAwareOfClientClose (ctx )
473+ defer cancel ()
474+
481475 resp , err := c .sendHTTP (ctx , http .MethodPost , bytes .NewReader (requestBody ), "application/json, text/event-stream" )
482476 if err != nil {
483477 return fmt .Errorf ("failed to send request: %w" , err )
@@ -555,7 +549,9 @@ var (
555549
556550func (c * StreamableHTTP ) createGETConnectionToServer () error {
557551
558- ctx := context .Background () // the sendHTTP will be automatically canceled when the client is closed
552+ ctx , cancel := c .contextAwareOfClientClose (context .Background ())
553+ defer cancel ()
554+
559555 resp , err := c .sendHTTP (ctx , http .MethodGet , nil , "text/event-stream" )
560556 if err != nil {
561557 return fmt .Errorf ("failed to send request: %w" , err )
@@ -585,3 +581,17 @@ func (c *StreamableHTTP) createGETConnectionToServer() error {
585581
586582 return nil
587583}
584+
585+ func (c * StreamableHTTP ) contextAwareOfClientClose (ctx context.Context ) (context.Context , context.CancelFunc ) {
586+ newCtx , cancel := context .WithCancel (ctx )
587+ go func () {
588+ select {
589+ case <- c .closed :
590+ cancel ()
591+ case <- newCtx .Done ():
592+ // The original context was canceled
593+ cancel ()
594+ }
595+ }()
596+ return newCtx , cancel
597+ }
0 commit comments