Skip to content

Commit c7e9b30

Browse files
committed
feat: add custom handler support for all MCP server methods
- Add custom handler fields to MCPServer struct for all basic MCP methods - Implement custom handler logic in all handle* methods with proper error handling - Support custom handlers for: Initialize, Ping, SetLevel, ListResources, ListResourceTemplates, ReadResource, ListPrompts, GetPrompt, ListTools, CallTool, and Notification methods - Maintain backward compatibility by falling back to default behavior when custom handlers are not set - Enable more flexible server customization and middleware integration
1 parent 47e9419 commit c7e9b30

File tree

1 file changed

+138
-4
lines changed

1 file changed

+138
-4
lines changed

server/server.go

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ type MCPServer struct {
171171
paginationLimit *int
172172
sessions sync.Map
173173
hooks *Hooks
174+
175+
// custom handlers for basic methods
176+
InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
177+
PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error)
178+
ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error)
179+
ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error)
180+
ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error)
181+
ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error)
182+
GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
183+
ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
184+
CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
185+
SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error)
186+
NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification)
174187
}
175188

176189
// WithPaginationLimit sets the pagination limit for the server.
@@ -647,9 +660,21 @@ func (s *MCPServer) AddNotificationHandler(
647660

648661
func (s *MCPServer) handleInitialize(
649662
ctx context.Context,
650-
_ any,
663+
id any,
651664
request mcp.InitializeRequest,
652665
) (*mcp.InitializeResult, *requestError) {
666+
if s.InitializeHandler != nil {
667+
result, err := s.InitializeHandler(ctx, request)
668+
if err != nil {
669+
return nil, &requestError{
670+
id: id,
671+
code: mcp.INTERNAL_ERROR,
672+
err: err,
673+
}
674+
}
675+
return result, nil
676+
}
677+
653678
capabilities := mcp.ServerCapabilities{}
654679

655680
// Only add resource capabilities if they're configured
@@ -729,10 +754,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string {
729754
}
730755

731756
func (s *MCPServer) handlePing(
732-
_ context.Context,
733-
_ any,
734-
_ mcp.PingRequest,
757+
ctx context.Context,
758+
id any,
759+
request mcp.PingRequest,
735760
) (*mcp.EmptyResult, *requestError) {
761+
if s.PingHandler != nil {
762+
result, err := s.PingHandler(ctx, request)
763+
if err != nil {
764+
return nil, &requestError{
765+
id: id,
766+
code: mcp.INTERNAL_ERROR,
767+
err: err,
768+
}
769+
}
770+
return result, nil
771+
}
736772
return &mcp.EmptyResult{}, nil
737773
}
738774

@@ -741,6 +777,18 @@ func (s *MCPServer) handleSetLevel(
741777
id any,
742778
request mcp.SetLevelRequest,
743779
) (*mcp.EmptyResult, *requestError) {
780+
if s.SetLevelHandler != nil {
781+
result, err := s.SetLevelHandler(ctx, request)
782+
if err != nil {
783+
return nil, &requestError{
784+
id: id,
785+
code: mcp.INTERNAL_ERROR,
786+
err: err,
787+
}
788+
}
789+
return result, nil
790+
}
791+
744792
clientSession := ClientSessionFromContext(ctx)
745793
if clientSession == nil || !clientSession.Initialized() {
746794
return nil, &requestError{
@@ -820,6 +868,18 @@ func (s *MCPServer) handleListResources(
820868
id any,
821869
request mcp.ListResourcesRequest,
822870
) (*mcp.ListResourcesResult, *requestError) {
871+
if s.ListResourcesHandler != nil {
872+
result, err := s.ListResourcesHandler(ctx, request)
873+
if err != nil {
874+
return nil, &requestError{
875+
id: id,
876+
code: mcp.INTERNAL_ERROR,
877+
err: err,
878+
}
879+
}
880+
return result, nil
881+
}
882+
823883
s.resourcesMu.RLock()
824884
resources := make([]mcp.Resource, 0, len(s.resources))
825885
for _, entry := range s.resources {
@@ -858,6 +918,18 @@ func (s *MCPServer) handleListResourceTemplates(
858918
id any,
859919
request mcp.ListResourceTemplatesRequest,
860920
) (*mcp.ListResourceTemplatesResult, *requestError) {
921+
if s.ListResourceTemplatesHandler != nil {
922+
result, err := s.ListResourceTemplatesHandler(ctx, request)
923+
if err != nil {
924+
return nil, &requestError{
925+
id: id,
926+
code: mcp.INTERNAL_ERROR,
927+
err: err,
928+
}
929+
}
930+
return result, nil
931+
}
932+
861933
s.resourcesMu.RLock()
862934
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
863935
for _, entry := range s.resourceTemplates {
@@ -894,6 +966,18 @@ func (s *MCPServer) handleReadResource(
894966
id any,
895967
request mcp.ReadResourceRequest,
896968
) (*mcp.ReadResourceResult, *requestError) {
969+
if s.ReadResourceHandler != nil {
970+
result, err := s.ReadResourceHandler(ctx, request)
971+
if err != nil {
972+
return nil, &requestError{
973+
id: id,
974+
code: mcp.INTERNAL_ERROR,
975+
err: err,
976+
}
977+
}
978+
return result, nil
979+
}
980+
897981
s.resourcesMu.RLock()
898982
// First try direct resource handlers
899983
if entry, ok := s.resources[request.Params.URI]; ok {
@@ -972,6 +1056,18 @@ func (s *MCPServer) handleListPrompts(
9721056
id any,
9731057
request mcp.ListPromptsRequest,
9741058
) (*mcp.ListPromptsResult, *requestError) {
1059+
if s.ListPromptsHandler != nil {
1060+
result, err := s.ListPromptsHandler(ctx, request)
1061+
if err != nil {
1062+
return nil, &requestError{
1063+
id: id,
1064+
code: mcp.INTERNAL_ERROR,
1065+
err: err,
1066+
}
1067+
}
1068+
return result, nil
1069+
}
1070+
9751071
s.promptsMu.RLock()
9761072
prompts := make([]mcp.Prompt, 0, len(s.prompts))
9771073
for _, prompt := range s.prompts {
@@ -1010,6 +1106,18 @@ func (s *MCPServer) handleGetPrompt(
10101106
id any,
10111107
request mcp.GetPromptRequest,
10121108
) (*mcp.GetPromptResult, *requestError) {
1109+
if s.GetPromptHandler != nil {
1110+
result, err := s.GetPromptHandler(ctx, request)
1111+
if err != nil {
1112+
return nil, &requestError{
1113+
id: id,
1114+
code: mcp.INTERNAL_ERROR,
1115+
err: err,
1116+
}
1117+
}
1118+
return result, nil
1119+
}
1120+
10131121
s.promptsMu.RLock()
10141122
handler, ok := s.promptHandlers[request.Params.Name]
10151123
s.promptsMu.RUnlock()
@@ -1039,6 +1147,17 @@ func (s *MCPServer) handleListTools(
10391147
id any,
10401148
request mcp.ListToolsRequest,
10411149
) (*mcp.ListToolsResult, *requestError) {
1150+
if s.ListToolsHandler != nil {
1151+
result, err := s.ListToolsHandler(ctx, request)
1152+
if err != nil {
1153+
return nil, &requestError{
1154+
id: id,
1155+
code: mcp.INTERNAL_ERROR,
1156+
err: err,
1157+
}
1158+
}
1159+
return result, nil
1160+
}
10421161
// Get the base tools from the server
10431162
s.toolsMu.RLock()
10441163
tools := make([]mcp.Tool, 0, len(s.tools))
@@ -1129,6 +1248,17 @@ func (s *MCPServer) handleToolCall(
11291248
id any,
11301249
request mcp.CallToolRequest,
11311250
) (*mcp.CallToolResult, *requestError) {
1251+
if s.CallToolHandler != nil {
1252+
result, err := s.CallToolHandler(ctx, request)
1253+
if err != nil {
1254+
return nil, &requestError{
1255+
id: id,
1256+
code: mcp.INTERNAL_ERROR,
1257+
err: err,
1258+
}
1259+
}
1260+
return result, nil
1261+
}
11321262
// First check session-specific tools
11331263
var tool ServerTool
11341264
var ok bool
@@ -1188,6 +1318,10 @@ func (s *MCPServer) handleNotification(
11881318
ctx context.Context,
11891319
notification mcp.JSONRPCNotification,
11901320
) mcp.JSONRPCMessage {
1321+
if s.NotificationHandler != nil {
1322+
s.NotificationHandler(ctx, notification)
1323+
return nil
1324+
}
11911325
s.notificationHandlersMu.RLock()
11921326
handler, ok := s.notificationHandlers[notification.Method]
11931327
s.notificationHandlersMu.RUnlock()

0 commit comments

Comments
 (0)