forked from mark3labs/mcp-go
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstdio.go
More file actions
293 lines (251 loc) · 8.25 KB
/
stdio.go
File metadata and controls
293 lines (251 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
package server
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"os/signal"
"sync/atomic"
"syscall"
"github.com/mark3labs/mcp-go/mcp"
)
// StdioContextFunc is a function that takes an existing context and returns
// a potentially modified context.
// This can be used to inject context values from environment variables,
// for example.
//
// Listen invokes the configured function at most once per call, after the
// server has wrapped the incoming context, and uses the returned context for
// the notification handler and for every message it processes.
type StdioContextFunc func(ctx context.Context) context.Context
// StdioServer wraps a MCPServer and handles stdio communication.
// It provides a simple way to create command-line MCP servers that
// communicate via standard input/output streams using JSON-RPC messages.
//
// Fields:
//   - server: the wrapped MCPServer; each parsed input line is passed to its
//     HandleMessage method.
//   - errLogger: receives read, message-handling and notification-write errors.
//   - contextFunc: optional hook; when non-nil, Listen applies it once to the
//     context before any message is processed.
type StdioServer struct {
server *MCPServer
errLogger *log.Logger
contextFunc StdioContextFunc
}
// StdioOption defines a function type for configuring StdioServer.
// ServeStdio applies each supplied option, in order, to the server it
// creates before it starts listening.
type StdioOption func(*StdioServer)
// WithErrorLogger returns a StdioOption that installs logger as the
// destination for the server's error messages.
func WithErrorLogger(logger *log.Logger) StdioOption {
	setLogger := func(srv *StdioServer) {
		srv.errLogger = logger
	}
	return setLogger
}
// WithStdioContextFunc returns a StdioOption that sets the function used to
// customise the context handed to the server. The stdio server uses the same
// context for all requests, so fn is called only once per Listen call rather
// than once per message.
func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
	return func(s *StdioServer) {
		s.contextFunc = fn
	}
}
// stdioSession is a static client session, since stdio has only one client.
//
// notifications carries the notifications that handleNotifications drains and
// writes to the output stream; initialized records, atomically, whether
// Initialize has been called.
type stdioSession struct {
notifications chan mcp.JSONRPCNotification
initialized atomic.Bool
}
// SessionID returns the fixed identifier "stdio"; every call yields the same
// value because the stdio transport has a single client.
func (s *stdioSession) SessionID() string {
	const id = "stdio"
	return id
}
// NotificationChannel exposes the session's notification queue as a
// send-only channel.
func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	var out chan<- mcp.JSONRPCNotification = s.notifications
	return out
}
// Initialize marks the session as initialized. The flag is stored atomically.
func (s *stdioSession) Initialize() {
	const ready = true
	s.initialized.Store(ready)
}
// Initialized reports whether Initialize has been called on the session.
func (s *stdioSession) Initialized() bool {
	ready := s.initialized.Load()
	return ready
}
// Compile-time assertion that *stdioSession satisfies ClientSession.
var _ ClientSession = (*stdioSession)(nil)
// stdioSessionInstance is the single package-level session shared by Listen
// and handleNotifications. Its notification channel is buffered to hold up to
// 100 pending notifications.
var stdioSessionInstance = stdioSession{
notifications: make(chan mcp.JSONRPCNotification, 100),
}
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
// By default, errors are logged to os.Stderr with the standard log flags
// (log.LstdFlags) and no prefix. Call SetErrorLogger to redirect or silence
// them.
func NewStdioServer(server *MCPServer) *StdioServer {
	return &StdioServer{
		server: server,
		// Default to logging on stderr.
		errLogger: log.New(os.Stderr, "", log.LstdFlags),
	}
}
// SetErrorLogger replaces the logger that receives the StdioServer's error
// messages. Errors reported after the call are written to the supplied logger.
func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
	s.errLogger = logger
}
// SetContextFunc installs fn as the hook that customises the context passed
// to the server. Because the stdio server shares one context across all
// requests, Listen calls fn only once rather than once per message.
func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
	s.contextFunc = fn
}
// handleNotifications forwards every notification queued on the shared stdio
// session to stdout until ctx is cancelled.
// A failed write is reported through the error logger and the loop carries on
// with the next notification.
func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
	cancelled := ctx.Done()
	for {
		select {
		case <-cancelled:
			return
		case msg := <-stdioSessionInstance.notifications:
			err := s.writeResponse(msg, stdout)
			if err != nil {
				s.errLogger.Printf("Error writing notification: %v", err)
			}
		}
	}
}
// processInputStream is the server's main read loop: it reads one line at a
// time from reader and hands it to processMessage, writing replies to stdout.
//
// The loop ends when:
//   - ctx is cancelled (the context's error is returned),
//   - reading or message processing reports io.EOF (nil is returned), or
//   - reading or message processing fails in any other way (the error is
//     logged and returned).
func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
	for {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		line, readErr := s.readNextLine(ctx, reader)
		switch {
		case readErr == io.EOF:
			return nil
		case readErr != nil:
			s.errLogger.Printf("Error reading input: %v", readErr)
			return readErr
		}

		switch handleErr := s.processMessage(ctx, line, stdout); handleErr {
		case nil:
			continue
		case io.EOF:
			return nil
		default:
			s.errLogger.Printf("Error handling message: %v", handleErr)
			return handleErr
		}
	}
}
// readNextLine reads a single line from the input reader in a context-aware
// manner. The blocking read runs in a background goroutine so that the caller
// can stop waiting as soon as ctx is cancelled.
//
// Returns:
//   - the line (including its trailing newline, when present) and a nil error
//     on success;
//   - "" and ctx.Err() if the context is cancelled first;
//   - "" and the read error otherwise; io.EOF signals that the input stream
//     is closed.
//
// A final line that is not newline-terminated is returned as a normal line
// instead of being discarded together with the io.EOF that accompanies it;
// a following call is then expected to report io.EOF.
//
// If the context is cancelled while the read is still blocked, the goroutine
// may remain blocked in ReadString until the reader yields data or fails.
func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
	type readResult struct {
		line string
		err  error
	}

	// Buffered so the reading goroutine can always hand over its result and
	// exit, even when nobody is waiting for it any more.
	resultChan := make(chan readResult, 1)
	done := make(chan struct{})
	defer close(done)

	go func() {
		// Do not consume input if the caller has already given up.
		select {
		case <-done:
			return
		default:
		}

		line, err := reader.ReadString('\n')
		if err == io.EOF && line != "" {
			// ReadString returns the data read so far together with io.EOF
			// when the stream ends without a delimiter; keep that last line.
			err = nil
		}
		resultChan <- readResult{line: line, err: err}
	}()

	select {
	case <-ctx.Done():
		return "", ctx.Err()
	case res := <-resultChan:
		if res.err != nil {
			return "", res.err
		}
		return res.line, nil
	}
}
// Listen starts listening for JSON-RPC messages on the provided input and
// writes responses to the provided output.
//
// It registers the shared stdio session with the wrapped server (and
// unregisters it on return), applies the optional context function once, and
// starts a background goroutine that forwards notifications to stdout.
// That goroutine is stopped when Listen returns.
//
// It runs until the context is cancelled, the input reaches EOF, or an error
// occurs. It returns an error if the session cannot be registered or if
// reading input or writing output fails; EOF on the input yields nil.
func (s *StdioServer) Listen(
	ctx context.Context,
	stdin io.Reader,
	stdout io.Writer,
) error {
	// Set a static client context since stdio only has one client
	if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
		return fmt.Errorf("register session: %w", err)
	}
	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
	ctx = s.server.WithContext(ctx, &stdioSessionInstance)

	// Add in any custom context.
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx)
	}

	// Scope the context to this call. Without this, the notification handler
	// would keep running (and keep writing to stdout) after Listen returns on
	// EOF or on an error, because the caller's context is still live.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	reader := bufio.NewReader(stdin)

	// Start notification handler
	go s.handleNotifications(ctx, stdout)

	return s.processInputStream(ctx, reader, stdout)
}
// processMessage handles one line of input as a JSON-RPC message.
//
// A line that is not valid JSON is answered with a PARSE_ERROR response, and
// the result of writing that response is returned as is. Otherwise the raw
// message is passed to the wrapped MCPServer; a non-nil reply is written to
// writer, while a nil reply (as for notifications) produces no output.
// A failure to write the reply is returned wrapped with context.
func (s *StdioServer) processMessage(
	ctx context.Context,
	line string,
	writer io.Writer,
) error {
	var raw json.RawMessage
	if parseErr := json.Unmarshal([]byte(line), &raw); parseErr != nil {
		return s.writeResponse(createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error"), writer)
	}

	reply := s.server.HandleMessage(ctx, raw)
	if reply == nil {
		// Nothing to send back (not for notifications).
		return nil
	}

	if writeErr := s.writeResponse(reply, writer); writeErr != nil {
		return fmt.Errorf("failed to write response: %w", writeErr)
	}
	return nil
}
// writeResponse encodes response as JSON and writes it to writer as a single
// newline-terminated line.
// It returns the marshaling error or the write error, whichever occurs first.
func (s *StdioServer) writeResponse(
	response mcp.JSONRPCMessage,
	writer io.Writer,
) error {
	payload, err := json.Marshal(response)
	if err != nil {
		return err
	}

	// Emit the message and its terminating newline in one write.
	_, err = writer.Write(append(payload, '\n'))
	return err
}
// ServeStdio is a convenience function that creates and starts a StdioServer
// with os.Stdin and os.Stdout.
//
// The server logs errors to os.Stderr unless an option overrides the logger;
// opts are applied in order before listening starts. SIGTERM and SIGINT
// cancel the server's context for a graceful shutdown, and the signal
// registration is released again when ServeStdio returns.
//
// Returns whatever Listen returns: nil when the input reaches EOF, otherwise
// the error that stopped the server (including the context error after a
// signal).
func ServeStdio(server *MCPServer, opts ...StdioOption) error {
	s := NewStdioServer(server)
	s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags))

	for _, opt := range opts {
		opt(s)
	}

	// NotifyContext cancels ctx on the first SIGTERM/SIGINT. The deferred stop
	// unregisters the signal handler and releases its resources, so nothing
	// is left waiting for a signal once the server has finished.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
	defer stop()

	return s.Listen(ctx, os.Stdin, os.Stdout)
}