@@ -1314,3 +1314,170 @@ func TestInsecureStatefulSessionIdManager(t *testing.T) {
13141314 }
13151315 })
13161316}
1317+
1318+ func TestStreamableHTTP_SendNotificationToSpecificClient (t * testing.T ) {
1319+ t .Run ("POST session registration enables SendNotificationToSpecificClient" , func (t * testing.T ) {
1320+ hooks := & Hooks {}
1321+ var registeredSessionID string
1322+ var mu sync.Mutex
1323+ var sessionRegistered sync.WaitGroup
1324+ sessionRegistered .Add (1 )
1325+
1326+ hooks .AddOnRegisterSession (func (ctx context.Context , session ClientSession ) {
1327+ mu .Lock ()
1328+ registeredSessionID = session .SessionID ()
1329+ mu .Unlock ()
1330+ sessionRegistered .Done ()
1331+ })
1332+
1333+ mcpServer := NewMCPServer ("test" , "1.0.0" , WithHooks (hooks ))
1334+ testServer := NewTestStreamableHTTPServer (mcpServer )
1335+ defer testServer .Close ()
1336+
1337+ // Send initialize request to register session
1338+ resp , err := postJSON (testServer .URL , initRequest )
1339+ if err != nil {
1340+ t .Fatalf ("Failed to send initialize request: %v" , err )
1341+ }
1342+ defer resp .Body .Close ()
1343+
1344+ if resp .StatusCode != http .StatusOK {
1345+ t .Fatalf ("Expected status 200, got %d" , resp .StatusCode )
1346+ }
1347+
1348+ // Get session ID from response header
1349+ sessionID := resp .Header .Get (HeaderKeySessionID )
1350+ if sessionID == "" {
1351+ t .Fatal ("Expected session ID in response header" )
1352+ }
1353+
1354+ // Wait for session registration
1355+ done := make (chan struct {})
1356+ go func () {
1357+ sessionRegistered .Wait ()
1358+ close (done )
1359+ }()
1360+
1361+ select {
1362+ case <- done :
1363+ // Session registered successfully
1364+ case <- time .After (2 * time .Second ):
1365+ t .Fatal ("Timeout waiting for session registration" )
1366+ }
1367+
1368+ mu .Lock ()
1369+ if registeredSessionID != sessionID {
1370+ t .Errorf ("Expected registered session ID %s, got %s" , sessionID , registeredSessionID )
1371+ }
1372+ mu .Unlock ()
1373+
1374+ // Now test SendNotificationToSpecificClient
1375+ err = mcpServer .SendNotificationToSpecificClient (sessionID , "test/notification" , map [string ]any {
1376+ "message" : "test notification" ,
1377+ })
1378+ if err != nil {
1379+ t .Errorf ("SendNotificationToSpecificClient failed: %v" , err )
1380+ }
1381+ })
1382+
1383+ t .Run ("Session reuse for non-initialize requests" , func (t * testing.T ) {
1384+ mcpServer := NewMCPServer ("test" , "1.0.0" )
1385+
1386+ // Add a tool that sends a notification
1387+ mcpServer .AddTool (mcp .NewTool ("notify_tool" ), func (ctx context.Context , req mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1388+ session := ClientSessionFromContext (ctx )
1389+ if session == nil {
1390+ return mcp .NewToolResultError ("no session in context" ), nil
1391+ }
1392+
1393+ // Try to send notification to specific client
1394+ server := ServerFromContext (ctx )
1395+ err := server .SendNotificationToSpecificClient (session .SessionID (), "tool/notification" , map [string ]any {
1396+ "from" : "tool" ,
1397+ })
1398+ if err != nil {
1399+ return mcp .NewToolResultError (fmt .Sprintf ("notification failed: %v" , err )), nil
1400+ }
1401+
1402+ return mcp .NewToolResultText ("notification sent" ), nil
1403+ })
1404+
1405+ testServer := NewTestStreamableHTTPServer (mcpServer )
1406+ defer testServer .Close ()
1407+
1408+ // Initialize session
1409+ resp , err := postJSON (testServer .URL , initRequest )
1410+ if err != nil {
1411+ t .Fatalf ("Failed to send initialize request: %v" , err )
1412+ }
1413+ sessionID := resp .Header .Get (HeaderKeySessionID )
1414+ resp .Body .Close ()
1415+
1416+ if sessionID == "" {
1417+ t .Fatal ("Expected session ID in response header" )
1418+ }
1419+
1420+ // Give time for registration to complete
1421+ time .Sleep (100 * time .Millisecond )
1422+
1423+ // Call tool with the session ID
1424+ toolCallRequest := map [string ]any {
1425+ "jsonrpc" : "2.0" ,
1426+ "id" : 2 ,
1427+ "method" : "tools/call" ,
1428+ "params" : map [string ]any {
1429+ "name" : "notify_tool" ,
1430+ },
1431+ }
1432+
1433+ jsonBody , _ := json .Marshal (toolCallRequest )
1434+ req , _ := http .NewRequest (http .MethodPost , testServer .URL , bytes .NewBuffer (jsonBody ))
1435+ req .Header .Set ("Content-Type" , "application/json" )
1436+ req .Header .Set (HeaderKeySessionID , sessionID )
1437+
1438+ resp , err = http .DefaultClient .Do (req )
1439+ if err != nil {
1440+ t .Fatalf ("Failed to call tool: %v" , err )
1441+ }
1442+ defer resp .Body .Close ()
1443+
1444+ bodyBytes , _ := io .ReadAll (resp .Body )
1445+ bodyStr := string (bodyBytes )
1446+
1447+ // Response might be SSE format if notification was sent
1448+ var toolResponse jsonRPCResponse
1449+ if strings .HasPrefix (bodyStr , "event: message" ) {
1450+ // Parse SSE format
1451+ lines := strings .Split (bodyStr , "\n " )
1452+ for _ , line := range lines {
1453+ if strings .HasPrefix (line , "data: " ) {
1454+ jsonData := strings .TrimPrefix (line , "data: " )
1455+ if err := json .Unmarshal ([]byte (jsonData ), & toolResponse ); err == nil {
1456+ break
1457+ }
1458+ }
1459+ }
1460+ } else {
1461+ if err := json .Unmarshal (bodyBytes , & toolResponse ); err != nil {
1462+ t .Fatalf ("Failed to unmarshal response: %v. Body: %s" , err , bodyStr )
1463+ }
1464+ }
1465+
1466+ if toolResponse .Error != nil {
1467+ t .Errorf ("Tool call failed: %v" , toolResponse .Error )
1468+ }
1469+
1470+ // Verify the tool result indicates success
1471+ if result , ok := toolResponse .Result ["content" ].([]any ); ok {
1472+ if len (result ) > 0 {
1473+ if content , ok := result [0 ].(map [string ]any ); ok {
1474+ if text , ok := content ["text" ].(string ); ok {
1475+ if text != "notification sent" {
1476+ t .Errorf ("Expected 'notification sent', got %s" , text )
1477+ }
1478+ }
1479+ }
1480+ }
1481+ }
1482+ })
1483+ }
0 commit comments