Skip to content

Commit 7afc2e5

Browse files
feat(server): Add hooks.AddOnUnregisterSession functionality (#175)
Add OnUnregisterSession hook functionality to complement the existing OnRegisterSession hooks, allowing code to run when a client session is being removed from the server. In some cases, the server may want to do additional work when a session has been closed. For example, in the SSE server case where you may end up managing various logs for the duration of the session -- you would want to indicate that the session was finished.
1 parent ee6757f commit 7afc2e5

File tree

6 files changed

+114
-8
lines changed

6 files changed

+114
-8
lines changed

server/hooks.go

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/internal/gen/hooks.go.tmpl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
1515
type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
1616

17+
// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered.
18+
type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession)
1719

1820
// BeforeAnyHookFunc is a function that is called after the request is
1921
// parsed but before the method is called.
@@ -63,7 +65,8 @@ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.
6365
{{end}}
6466

6567
type Hooks struct {
66-
OnRegisterSession []OnRegisterSessionHookFunc
68+
OnRegisterSession []OnRegisterSessionHookFunc
69+
OnUnregisterSession []OnUnregisterSessionHookFunc
6770
OnBeforeAny []BeforeAnyHookFunc
6871
OnSuccess []OnSuccessHookFunc
6972
OnError []OnErrorHookFunc
@@ -183,6 +186,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
183186
}
184187
}
185188

189+
func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) {
190+
c.OnUnregisterSession = append(c.OnUnregisterSession, hook)
191+
}
192+
193+
func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) {
194+
if c == nil {
195+
return
196+
}
197+
for _, hook := range c.OnUnregisterSession {
198+
hook(ctx, session)
199+
}
200+
}
201+
186202
{{- range .}}
187203
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
188204
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)

server/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,11 @@ func (s *MCPServer) RegisterSession(
206206

207207
// UnregisterSession removes from storage session that is shut down.
208208
func (s *MCPServer) UnregisterSession(
209+
ctx context.Context,
209210
sessionID string,
210211
) {
211-
s.sessions.Delete(sessionID)
212+
session, _ := s.sessions.LoadAndDelete(sessionID)
213+
s.hooks.UnregisterSession(ctx, session.(ClientSession))
212214
}
213215

214216
// SendNotificationToAllClients sends a notification to all the currently active clients.

server/server_test.go

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ func TestMCPServer_Tools(t *testing.T) {
340340
}`))
341341
tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage))
342342
})
343-
344343
}
345344
}
346345

@@ -794,8 +793,8 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) {
794793
message: `{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`,
795794
expectedErr: mcp.INVALID_REQUEST,
796795
validateErr: func(t *testing.T, err error) {
797-
var unparseableErr = &UnparseableMessageError{}
798-
var ok = errors.As(err, &unparseableErr)
796+
unparseableErr := &UnparseableMessageError{}
797+
ok := errors.As(err, &unparseableErr)
799798
assert.True(t, ok, "Error should be UnparseableMessageError")
800799
assert.Equal(t, mcp.MethodInitialize, unparseableErr.GetMethod())
801800
assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparseableErr.GetMessage())
@@ -1194,7 +1193,6 @@ func TestMCPServer_ResourceTemplates(t *testing.T) {
11941193
assert.Equal(t, "test://something/test-resource/a/b/c", resultContent.URI)
11951194
assert.Equal(t, "text/plain", resultContent.MIMEType)
11961195
assert.Equal(t, "test content: something", resultContent.Text)
1197-
11981196
})
11991197
}
12001198

@@ -1422,6 +1420,76 @@ func TestMCPServer_WithHooks(t *testing.T) {
14221420
assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result")
14231421
}
14241422

1423+
func TestMCPServer_SessionHooks(t *testing.T) {
1424+
var (
1425+
registerCalled bool
1426+
unregisterCalled bool
1427+
1428+
registeredContext context.Context
1429+
unregisteredContext context.Context
1430+
1431+
registeredSession ClientSession
1432+
unregisteredSession ClientSession
1433+
)
1434+
1435+
hooks := &Hooks{}
1436+
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
1437+
registerCalled = true
1438+
registeredContext = ctx
1439+
registeredSession = session
1440+
})
1441+
hooks.AddOnUnregisterSession(func(ctx context.Context, session ClientSession) {
1442+
unregisterCalled = true
1443+
unregisteredContext = ctx
1444+
unregisteredSession = session
1445+
})
1446+
1447+
server := NewMCPServer(
1448+
"test-server",
1449+
"1.0.0",
1450+
WithHooks(hooks),
1451+
)
1452+
1453+
testSession := &fakeSession{
1454+
sessionID: "test-session-id",
1455+
notificationChannel: make(chan mcp.JSONRPCNotification, 5),
1456+
initialized: false,
1457+
}
1458+
1459+
ctx := context.WithoutCancel(context.Background())
1460+
err := server.RegisterSession(ctx, testSession)
1461+
require.NoError(t, err)
1462+
1463+
assert.True(t, registerCalled, "Register session hook was not called")
1464+
assert.Equal(t, testSession.SessionID(), registeredSession.SessionID(),
1465+
"Register hook received wrong session")
1466+
1467+
server.UnregisterSession(ctx, testSession.SessionID())
1468+
1469+
assert.True(t, unregisterCalled, "Unregister session hook was not called")
1470+
assert.Equal(t, testSession.SessionID(), unregisteredSession.SessionID(),
1471+
"Unregister hook received wrong session")
1472+
1473+
assert.Equal(t, ctx, unregisteredContext, "Unregister hook received wrong context")
1474+
assert.Equal(t, ctx, registeredContext, "Register hook received wrong context")
1475+
}
1476+
1477+
func TestMCPServer_SessionHooks_NilHooks(t *testing.T) {
1478+
server := NewMCPServer("test-server", "1.0.0")
1479+
1480+
testSession := &fakeSession{
1481+
sessionID: "test-session-id",
1482+
notificationChannel: make(chan mcp.JSONRPCNotification, 5),
1483+
initialized: false,
1484+
}
1485+
1486+
ctx := context.WithoutCancel(context.Background())
1487+
err := server.RegisterSession(ctx, testSession)
1488+
require.NoError(t, err)
1489+
1490+
server.UnregisterSession(ctx, testSession.SessionID())
1491+
}
1492+
14251493
func TestMCPServer_WithRecover(t *testing.T) {
14261494
panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
14271495
panic("test panic")

server/sse.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
259259
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
260260
return
261261
}
262-
defer s.server.UnregisterSession(sessionID)
262+
defer s.server.UnregisterSession(r.Context(), sessionID)
263263

264264
// Start notification handler for this session
265265
go func() {
@@ -438,6 +438,7 @@ func (s *SSEServer) SendEventToSession(
438438
return fmt.Errorf("event queue full")
439439
}
440440
}
441+
441442
func (s *SSEServer) GetUrlPath(input string) (string, error) {
442443
parse, err := url.Parse(input)
443444
if err != nil {
@@ -449,6 +450,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
449450
func (s *SSEServer) CompleteSseEndpoint() string {
450451
return s.baseURL + s.basePath + s.sseEndpoint
451452
}
453+
452454
func (s *SSEServer) CompleteSsePath() string {
453455
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
454456
if err != nil {
@@ -460,6 +462,7 @@ func (s *SSEServer) CompleteSsePath() string {
460462
func (s *SSEServer) CompleteMessageEndpoint() string {
461463
return s.baseURL + s.basePath + s.messageEndpoint
462464
}
465+
463466
func (s *SSEServer) CompleteMessagePath() string {
464467
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
465468
if err != nil {

server/stdio.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func (s *StdioServer) Listen(
204204
if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
205205
return fmt.Errorf("register session: %w", err)
206206
}
207-
defer s.server.UnregisterSession(stdioSessionInstance.SessionID())
207+
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
208208
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
209209

210210
// Add in any custom context.

0 commit comments

Comments
 (0)