Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions server/hooks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 17 additions & 1 deletion server/internal/gen/hooks.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)

// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered.
type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession)

// BeforeAnyHookFunc is a function that is called after the request is
// parsed but before the method is called.
Expand Down Expand Up @@ -63,7 +65,8 @@ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.
{{end}}

type Hooks struct {
OnRegisterSession []OnRegisterSessionHookFunc
OnRegisterSession []OnRegisterSessionHookFunc
OnUnregisterSession []OnUnregisterSessionHookFunc
OnBeforeAny []BeforeAnyHookFunc
OnSuccess []OnSuccessHookFunc
OnError []OnErrorHookFunc
Expand Down Expand Up @@ -183,6 +186,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
}
}

func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) {
c.OnUnregisterSession = append(c.OnUnregisterSession, hook)
}

func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) {
if c == nil {
return
}
for _, hook := range c.OnUnregisterSession {
hook(ctx, session)
}
}

{{- range .}}
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)
Expand Down
4 changes: 3 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ func (s *MCPServer) RegisterSession(

// UnregisterSession removes from storage session that is shut down.
func (s *MCPServer) UnregisterSession(
ctx context.Context,
sessionID string,
) {
s.sessions.Delete(sessionID)
session, _ := s.sessions.LoadAndDelete(sessionID)
s.hooks.UnregisterSession(ctx, session.(ClientSession))
}

// sendNotificationToAllClients sends a notification to all the currently active clients.
Expand Down
76 changes: 72 additions & 4 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ func TestMCPServer_Tools(t *testing.T) {
}`))
tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage))
})

}
}

Expand Down Expand Up @@ -725,8 +724,8 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) {
message: `{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`,
expectedErr: mcp.INVALID_REQUEST,
validateErr: func(t *testing.T, err error) {
var unparseableErr = &UnparseableMessageError{}
var ok = errors.As(err, &unparseableErr)
unparseableErr := &UnparseableMessageError{}
ok := errors.As(err, &unparseableErr)
assert.True(t, ok, "Error should be UnparseableMessageError")
assert.Equal(t, mcp.MethodInitialize, unparseableErr.GetMethod())
assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparseableErr.GetMessage())
Expand Down Expand Up @@ -1125,7 +1124,6 @@ func TestMCPServer_ResourceTemplates(t *testing.T) {
assert.Equal(t, "test://something/test-resource/a/b/c", resultContent.URI)
assert.Equal(t, "text/plain", resultContent.MIMEType)
assert.Equal(t, "test content: something", resultContent.Text)

})
}

Expand Down Expand Up @@ -1353,6 +1351,76 @@ func TestMCPServer_WithHooks(t *testing.T) {
assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result")
}

func TestMCPServer_SessionHooks(t *testing.T) {
var (
registerCalled bool
unregisterCalled bool

registeredContext context.Context
unregisteredContext context.Context

registeredSession ClientSession
unregisteredSession ClientSession
)

hooks := &Hooks{}
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
registerCalled = true
registeredContext = ctx
registeredSession = session
})
hooks.AddOnUnregisterSession(func(ctx context.Context, session ClientSession) {
unregisterCalled = true
unregisteredContext = ctx
unregisteredSession = session
})

server := NewMCPServer(
"test-server",
"1.0.0",
WithHooks(hooks),
)

testSession := &fakeSession{
sessionID: "test-session-id",
notificationChannel: make(chan mcp.JSONRPCNotification, 5),
initialized: false,
}

ctx := context.WithoutCancel(context.Background())
err := server.RegisterSession(ctx, testSession)
require.NoError(t, err)

assert.True(t, registerCalled, "Register session hook was not called")
assert.Equal(t, testSession.SessionID(), registeredSession.SessionID(),
"Register hook received wrong session")

server.UnregisterSession(ctx, testSession.SessionID())

assert.True(t, unregisterCalled, "Unregister session hook was not called")
assert.Equal(t, testSession.SessionID(), unregisteredSession.SessionID(),
"Unregister hook received wrong session")

assert.Equal(t, ctx, unregisteredContext, "Unregister hook received wrong context")
assert.Equal(t, ctx, registeredContext, "Register hook received wrong context")
}

func TestMCPServer_SessionHooks_NilHooks(t *testing.T) {
server := NewMCPServer("test-server", "1.0.0")

testSession := &fakeSession{
sessionID: "test-session-id",
notificationChannel: make(chan mcp.JSONRPCNotification, 5),
initialized: false,
}

ctx := context.WithoutCancel(context.Background())
err := server.RegisterSession(ctx, testSession)
require.NoError(t, err)

server.UnregisterSession(ctx, testSession.SessionID())
}

func TestMCPServer_WithRecover(t *testing.T) {
panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
panic("test panic")
Expand Down
40 changes: 21 additions & 19 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,20 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
mu sync.RWMutex

mu sync.RWMutex
}

// SSEOption defines a function type for configuring SSEServer
Expand Down Expand Up @@ -161,12 +161,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
}

// Apply all options
Expand Down Expand Up @@ -259,7 +259,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
return
}
defer s.server.UnregisterSession(sessionID)
defer s.server.UnregisterSession(r.Context(), sessionID)

// Start notification handler for this session
go func() {
Expand Down Expand Up @@ -310,7 +310,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}()
}


// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()
Expand Down Expand Up @@ -439,6 +438,7 @@ func (s *SSEServer) SendEventToSession(
return fmt.Errorf("event queue full")
}
}

func (s *SSEServer) GetUrlPath(input string) (string, error) {
parse, err := url.Parse(input)
if err != nil {
Expand All @@ -450,6 +450,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
func (s *SSEServer) CompleteSseEndpoint() string {
return s.baseURL + s.basePath + s.sseEndpoint
}

func (s *SSEServer) CompleteSsePath() string {
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
if err != nil {
Expand All @@ -461,6 +462,7 @@ func (s *SSEServer) CompleteSsePath() string {
func (s *SSEServer) CompleteMessageEndpoint() string {
return s.baseURL + s.basePath + s.messageEndpoint
}

func (s *SSEServer) CompleteMessagePath() string {
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion server/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (s *StdioServer) Listen(
if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
return fmt.Errorf("register session: %w", err)
}
defer s.server.UnregisterSession(stdioSessionInstance.SessionID())
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
ctx = s.server.WithContext(ctx, &stdioSessionInstance)

// Add in any custom context.
Expand Down