diff --git a/README.md b/README.md index 049efed..1b7f8f7 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ mcp tools npx -y @modelcontextprotocol/server-filesystem ~ Uses HTTP and Server-Sent Events (SSE) to communicate with an MCP server via JSON-RPC 2.0. This is useful for connecting to remote servers that implement the MCP protocol. ```bash -mcp tools http://127.0.0.1:3001 +mcp tools http://localhost:3001/sse # Example: Use the everything sample server # docker run -p 3001:3001 --rm -it tzolov/mcp-everything-server:v1 @@ -340,7 +340,7 @@ mcp web https://ne.tools The web interface includes: - A sidebar listing all available tools, resources, and prompts -- Form-based and JSON-based parameter editing +- Form-based and JSON-based parameter editing - Formatted and raw JSON response views - Interactive parameter forms automatically generated from tool schemas - Support for complex parameter types (arrays, objects, nested structures) @@ -626,7 +626,7 @@ mcp proxy tool count_lines "Counts lines in a file" "file:string" -e "wc -l < \" The guard mode allows you to restrict access to specific tools, prompts, and resources based on pattern matching. This is useful for security purposes when: - Restricting potentially dangerous operations (file writes, deletions, etc.) -- Limiting the capabilities of AI assistants or applications +- Limiting the capabilities of AI assistants or applications - Providing read-only access to sensitive systems - Creating sandboxed environments for testing or demonstrations diff --git a/cmd/mcptools/commands/call.go b/cmd/mcptools/commands/call.go index ef548e3..4039400 100644 --- a/cmd/mcptools/commands/call.go +++ b/cmd/mcptools/commands/call.go @@ -1,11 +1,13 @@ package commands import ( + "context" "encoding/json" "fmt" "os" "strings" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -104,11 +106,36 @@ func CallCmd() *cobra.Command { switch entityType { case EntityTypeTool: - resp, execErr = mcpClient.CallTool(entityName, params) + var toolResponse *mcp.CallToolResult + request := mcp.CallToolRequest{} + request.Params.Name = entityName + request.Params.Arguments = params + toolResponse, execErr = mcpClient.CallTool(context.Background(), request) + if execErr == nil && toolResponse != nil { + resp = ConvertJSONToMap(toolResponse) + } else { + resp = map[string]any{} + } case EntityTypeRes: - resp, execErr = mcpClient.ReadResource(entityName) + var resourceResponse *mcp.ReadResourceResult + request := mcp.ReadResourceRequest{} + request.Params.URI = entityName + resourceResponse, execErr = mcpClient.ReadResource(context.Background(), request) + if execErr == nil && resourceResponse != nil { + resp = ConvertJSONToMap(resourceResponse) + } else { + resp = map[string]any{} + } case EntityTypePrompt: - resp, execErr = mcpClient.GetPrompt(entityName) + var promptResponse *mcp.GetPromptResult + request := mcp.GetPromptRequest{} + request.Params.Name = entityName + promptResponse, execErr = mcpClient.GetPrompt(context.Background(), request) + if execErr == nil && promptResponse != nil { + resp = ConvertJSONToMap(promptResponse) + } else { + resp = map[string]any{} + } default: fmt.Fprintf(os.Stderr, "Error: unsupported entity type: %s\n", entityType) os.Exit(1) diff --git a/cmd/mcptools/commands/get_prompt.go b/cmd/mcptools/commands/get_prompt.go index 16d9c50..30ff06a 100644 --- a/cmd/mcptools/commands/get_prompt.go +++ b/cmd/mcptools/commands/get_prompt.go @@ -1,10 +1,12 @@ package commands import ( + "context" "encoding/json" "fmt" "os" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -81,8 +83,18 @@ func GetPromptCmd() *cobra.Command { os.Exit(1) } - resp, execErr := mcpClient.GetPrompt(promptName) - if formatErr := FormatAndPrintResponse(thisCmd, resp, execErr); formatErr != nil { + request := mcp.GetPromptRequest{} + request.Params.Name = promptName + resp, execErr := mcpClient.GetPrompt(context.Background(), request) + + var responseMap map[string]any + if execErr == nil && resp != nil { + responseMap = ConvertJSONToMap(resp) + } else { + responseMap = map[string]any{} + } + + if formatErr := FormatAndPrintResponse(thisCmd, responseMap, execErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) os.Exit(1) } diff --git a/cmd/mcptools/commands/guard.go b/cmd/mcptools/commands/guard.go index 5fb763b..b182686 100644 --- a/cmd/mcptools/commands/guard.go +++ b/cmd/mcptools/commands/guard.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/f/mcptools/pkg/alias" - "github.com/f/mcptools/pkg/client" "github.com/f/mcptools/pkg/guard" "github.com/spf13/cobra" ) @@ -36,10 +35,10 @@ Examples: mcp guard --allow tools:read_* --deny edit_*,write_*,create_* npx run @modelcontextprotocol/server-filesystem ~ mcp guard --allow prompts:system_* --deny tools:execute_* npx run @modelcontextprotocol/server-filesystem ~ mcp guard --allow tools:read_* fs # Using an alias - + Patterns can include wildcards: * matches any sequence of characters - + Entity types: tools: filter available tools prompts: filter available prompts @@ -76,7 +75,7 @@ Entity types: if found { fmt.Fprintf(os.Stderr, "Expanding alias '%s' to '%s'\n", aliasName, serverCmd) // Replace the alias with the actual command - parsedArgs = client.ParseCommandString(serverCmd) + parsedArgs = ParseCommandString(serverCmd) } } diff --git a/cmd/mcptools/commands/prompts.go b/cmd/mcptools/commands/prompts.go index c15b79a..a618ea1 100644 --- a/cmd/mcptools/commands/prompts.go +++ b/cmd/mcptools/commands/prompts.go @@ -1,9 +1,11 @@ package commands import ( + "context" "fmt" "os" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -29,8 +31,15 @@ func PromptsCmd() *cobra.Command { os.Exit(1) } - resp, listErr := mcpClient.ListPrompts() - if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + resp, listErr := mcpClient.ListPrompts(context.Background(), mcp.ListPromptsRequest{}) + + var prompts []any + if listErr == nil && resp != nil { + prompts = ConvertJSONToSlice(resp.Prompts) + } + + promptsMap := map[string]any{"prompts": prompts} + if formatErr := FormatAndPrintResponse(thisCmd, promptsMap, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) os.Exit(1) } diff --git a/cmd/mcptools/commands/read_resource.go b/cmd/mcptools/commands/read_resource.go index 5229b97..b9c3af5 100644 --- a/cmd/mcptools/commands/read_resource.go +++ b/cmd/mcptools/commands/read_resource.go @@ -1,9 +1,11 @@ package commands import ( + "context" "fmt" "os" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -69,8 +71,18 @@ func ReadResourceCmd() *cobra.Command { os.Exit(1) } - resp, execErr := mcpClient.ReadResource(resourceName) - if formatErr := FormatAndPrintResponse(thisCmd, resp, execErr); formatErr != nil { + request := mcp.ReadResourceRequest{} + request.Params.URI = resourceName + resp, execErr := mcpClient.ReadResource(context.Background(), request) + + var responseMap map[string]any + if execErr == nil && resp != nil { + responseMap = ConvertJSONToMap(resp) + } else { + responseMap = map[string]any{} + } + + if formatErr := FormatAndPrintResponse(thisCmd, responseMap, execErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) os.Exit(1) } diff --git a/cmd/mcptools/commands/read_resource_test.go b/cmd/mcptools/commands/read_resource_test.go index 4b5ac5e..4583981 100644 --- a/cmd/mcptools/commands/read_resource_test.go +++ b/cmd/mcptools/commands/read_resource_test.go @@ -37,13 +37,11 @@ func TestReadResourceCmdRun_Success(t *testing.T) { // Given: a mock client that returns a successful read resource response mockResponse := map[string]any{ - "result": map[string]any{ - "contents": []any{ - map[string]any{ - "uri": "test://foo", - "mimeType": "text/plain", - "text": "bar", - }, + "contents": []any{ + map[string]any{ + "uri": "test://foo", + "mimeType": "text/plain", + "text": "bar", }, }, } diff --git a/cmd/mcptools/commands/resources.go b/cmd/mcptools/commands/resources.go index 9a79d1c..3bb814c 100644 --- a/cmd/mcptools/commands/resources.go +++ b/cmd/mcptools/commands/resources.go @@ -1,9 +1,11 @@ package commands import ( + "context" "fmt" "os" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -29,8 +31,15 @@ func ResourcesCmd() *cobra.Command { os.Exit(1) } - resp, listErr := mcpClient.ListResources() - if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + resp, listErr := mcpClient.ListResources(context.Background(), mcp.ListResourcesRequest{}) + + var resources []any + if listErr == nil && resp != nil { + resources = ConvertJSONToSlice(resp.Resources) + } + + resourcesMap := map[string]any{"resources": resources} + if formatErr := FormatAndPrintResponse(thisCmd, resourcesMap, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) os.Exit(1) } diff --git a/cmd/mcptools/commands/shell.go b/cmd/mcptools/commands/shell.go index 6ec56c1..0279307 100644 --- a/cmd/mcptools/commands/shell.go +++ b/cmd/mcptools/commands/shell.go @@ -1,6 +1,7 @@ package commands import ( + "context" "encoding/json" "errors" "fmt" @@ -8,7 +9,8 @@ import ( "path/filepath" "strings" - "github.com/f/mcptools/pkg/client" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" "github.com/peterh/liner" "github.com/spf13/cobra" ) @@ -48,18 +50,12 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo os.Exit(1) } - mcpClient, clientErr := CreateClientFunc(parsedArgs, client.CloseTransportAfterExecute(false)) + mcpClient, clientErr := CreateClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) } - _, listErr := mcpClient.ListTools() - if listErr != nil { - fmt.Fprintf(os.Stderr, "Error connecting to MCP server: %v\n", listErr) - os.Exit(1) - } - fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > MCP Tools Shell (%s)\n", Version) fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > Connected to Server: %s\n", strings.Join(parsedArgs, " ")) fmt.Fprintf(thisCmd.OutOrStdout(), "\nmcp > Type '/h' for help or '/q' to quit\n") @@ -111,19 +107,43 @@ func ShellCmd() *cobra.Command { //nolint:gocyclo switch command { case "tools": - resp, listErr = mcpClient.ListTools() + var listToolsResult *mcp.ListToolsResult + listToolsResult, listErr = mcpClient.ListTools(context.Background(), mcp.ListToolsRequest{}) + + var tools []any + if listErr == nil && listToolsResult != nil { + tools = ConvertJSONToSlice(listToolsResult.Tools) + } + + resp = map[string]any{"tools": tools} if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue } case "resources": - resp, listErr = mcpClient.ListResources() + var listResourcesResult *mcp.ListResourcesResult + listResourcesResult, listErr = mcpClient.ListResources(context.Background(), mcp.ListResourcesRequest{}) + + var resources []any + if listErr == nil && listResourcesResult != nil { + resources = ConvertJSONToSlice(listResourcesResult.Resources) + } + + resp = map[string]any{"resources": resources} if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue } case "prompts": - resp, listErr = mcpClient.ListPrompts() + var listPromptsResult *mcp.ListPromptsResult + listPromptsResult, listErr = mcpClient.ListPrompts(context.Background(), mcp.ListPromptsRequest{}) + + var prompts []any + if listErr == nil && listPromptsResult != nil { + prompts = ConvertJSONToSlice(listPromptsResult.Prompts) + } + + resp = map[string]any{"prompts": prompts} if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) continue @@ -208,11 +228,36 @@ func callCommand(thisCmd *cobra.Command, mcpClient *client.Client, commandArgs [ switch entityType { case EntityTypeTool: - resp, execErr = mcpClient.CallTool(entityName, params) + var toolResponse *mcp.CallToolResult + request := mcp.CallToolRequest{} + request.Params.Name = entityName + request.Params.Arguments = params + toolResponse, execErr = mcpClient.CallTool(context.Background(), request) + if execErr == nil && toolResponse != nil { + resp = ConvertJSONToMap(toolResponse) + } else { + resp = map[string]any{} + } case EntityTypeRes: - resp, execErr = mcpClient.ReadResource(entityName) + var resourceResponse *mcp.ReadResourceResult + request := mcp.ReadResourceRequest{} + request.Params.URI = entityName + resourceResponse, execErr = mcpClient.ReadResource(context.Background(), request) + if execErr == nil && resourceResponse != nil { + resp = ConvertJSONToMap(resourceResponse) + } else { + resp = map[string]any{} + } case EntityTypePrompt: - resp, execErr = mcpClient.GetPrompt(entityName) + var promptResponse *mcp.GetPromptResult + request := mcp.GetPromptRequest{} + request.Params.Name = entityName + promptResponse, execErr = mcpClient.GetPrompt(context.Background(), request) + if execErr == nil && promptResponse != nil { + resp = ConvertJSONToMap(promptResponse) + } else { + resp = map[string]any{} + } default: fmt.Fprintf(os.Stderr, "Error: unsupported entity type: %s\n", entityType) } diff --git a/cmd/mcptools/commands/shell_test.go b/cmd/mcptools/commands/shell_test.go index 9431dcc..82a147f 100644 --- a/cmd/mcptools/commands/shell_test.go +++ b/cmd/mcptools/commands/shell_test.go @@ -187,7 +187,7 @@ func TestShellCallCommand(t *testing.T) { name: "tool_name without params", input: "test-tool\n/q\n", expectedOutputs: []string{"Tool executed successfully"}, - expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{}}, + expectedParams: map[string]any{"name": "test-tool"}, mockResponses: map[string]map[string]any{ "tools/call": { "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, @@ -198,7 +198,7 @@ func TestShellCallCommand(t *testing.T) { name: "call tool without params", input: "call test-tool\n/q\n", expectedOutputs: []string{"Tool executed successfully"}, - expectedParams: map[string]any{"name": "test-tool", "arguments": map[string]any{}}, + expectedParams: map[string]any{"name": "test-tool"}, mockResponses: map[string]map[string]any{ "tools/call": { "content": []any{map[string]any{"type": "text", "text": "Tool executed successfully"}}, @@ -308,8 +308,9 @@ func TestShellCallCommand(t *testing.T) { } t.Errorf("expected method %q, got %q", method, mockResponse) } - if !reflect.DeepEqual(params, tt.expectedParams) { - t.Errorf("expected params %v, got %v", tt.expectedParams, params) + jsonParams := ConvertJSONToMap(params) + if !reflect.DeepEqual(jsonParams, tt.expectedParams) { + t.Errorf("expected params %v, got %v", tt.expectedParams, jsonParams) } return mockResponse, nil }) diff --git a/cmd/mcptools/commands/test_helpers.go b/cmd/mcptools/commands/test_helpers.go index 6cb19da..3115b99 100644 --- a/cmd/mcptools/commands/test_helpers.go +++ b/cmd/mcptools/commands/test_helpers.go @@ -3,9 +3,14 @@ package commands import ( "bytes" + "context" + "encoding/json" + "fmt" "testing" - "github.com/f/mcptools/pkg/client" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" ) // MockTransport implements the transport.Transport interface for testing. @@ -13,9 +18,40 @@ type MockTransport struct { ExecuteFunc func(method string, params any) (map[string]any, error) } -// Execute calls the mock implementation. -func (m *MockTransport) Execute(method string, params any) (map[string]any, error) { - return m.ExecuteFunc(method, params) +// Start is a no-op for the mock transport. +func (m *MockTransport) Start(_ context.Context) error { + return nil +} + +// SendRequest overrides the default implementation of the transport.SendRequest method. +func (m *MockTransport) SendRequest(_ context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if request.Method == "initialize" { + return &transport.JSONRPCResponse{Result: json.RawMessage(`{}`)}, nil + } + response, err := m.ExecuteFunc(request.Method, request.Params) + if err != nil { + return nil, err + } + responseBytes, err := json.Marshal(response) + if err != nil { + return nil, err + } + fmt.Println("Returning response:", string(responseBytes)) + return &transport.JSONRPCResponse{Result: json.RawMessage(responseBytes)}, nil +} + +// SendNotification is a no-op for the mock transport. +func (m *MockTransport) SendNotification(_ context.Context, _ mcp.JSONRPCNotification) error { + return nil +} + +// SetNotificationHandler is a no-op for the mock transport. +func (m *MockTransport) SetNotificationHandler(_ func(notification mcp.JSONRPCNotification)) { +} + +// Close is a no-op for the mock transport. +func (m *MockTransport) Close() error { + return nil } // setupMockClient creates a mock client with the given execute function and returns cleanup function. @@ -27,10 +63,11 @@ func setupMockClient(executeFunc func(method string, _ any) (map[string]any, err ExecuteFunc: executeFunc, } - mockClient := client.NewWithTransport(mockTransport) + mockClient := client.NewClient(mockTransport) + _, _ = mockClient.Initialize(context.Background(), mcp.InitializeRequest{}) // Override the function that creates clients - CreateClientFunc = func(_ []string, _ ...client.Option) (*client.Client, error) { + CreateClientFunc = func(_ []string, _ ...client.ClientOption) (*client.Client, error) { return mockClient, nil } diff --git a/cmd/mcptools/commands/tools.go b/cmd/mcptools/commands/tools.go index 321992e..77f2a35 100644 --- a/cmd/mcptools/commands/tools.go +++ b/cmd/mcptools/commands/tools.go @@ -1,9 +1,11 @@ package commands import ( + "context" "fmt" "os" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -21,7 +23,6 @@ func ToolsCmd() *cobra.Command { } parsedArgs := ProcessFlags(args) - mcpClient, err := CreateClientFunc(parsedArgs) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -29,8 +30,15 @@ func ToolsCmd() *cobra.Command { os.Exit(1) } - resp, listErr := mcpClient.ListTools() - if formatErr := FormatAndPrintResponse(thisCmd, resp, listErr); formatErr != nil { + resp, listErr := mcpClient.ListTools(context.Background(), mcp.ListToolsRequest{}) + + var tools []any + if listErr == nil && resp != nil { + tools = ConvertJSONToSlice(resp.Tools) + } + + toolsMap := map[string]any{"tools": tools} + if formatErr := FormatAndPrintResponse(thisCmd, toolsMap, listErr); formatErr != nil { fmt.Fprintf(os.Stderr, "%v\n", formatErr) os.Exit(1) } diff --git a/cmd/mcptools/commands/utils.go b/cmd/mcptools/commands/utils.go index 5c5c1c5..f3a9be0 100644 --- a/cmd/mcptools/commands/utils.go +++ b/cmd/mcptools/commands/utils.go @@ -1,12 +1,18 @@ package commands import ( + "bufio" + "context" + "encoding/json" "fmt" "strings" + "time" "github.com/f/mcptools/pkg/alias" - "github.com/f/mcptools/pkg/client" "github.com/f/mcptools/pkg/jsonutils" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/spf13/cobra" ) @@ -17,36 +23,68 @@ var ( // IsHTTP returns true if the string is a valid HTTP URL. func IsHTTP(str string) bool { - return strings.HasPrefix(str, "http://") || strings.HasPrefix(str, "https://") + return strings.HasPrefix(str, "http://") || strings.HasPrefix(str, "https://") || strings.HasPrefix(str, "localhost:") } // CreateClientFunc is the function used to create MCP clients. // This can be replaced in tests to use a mock transport. -var CreateClientFunc = func(args []string, opts ...client.Option) (*client.Client, error) { +var CreateClientFunc = func(args []string, _ ...client.ClientOption) (*client.Client, error) { if len(args) == 0 { return nil, ErrCommandRequired } - opts = append(opts, client.SetShowServerLogs(ShowServerLogs)) + // opts = append(opts, client.SetShowServerLogs(ShowServerLogs)) // Check if the first argument is an alias if len(args) == 1 { server, found := alias.GetServerCommand(args[0]) if found { - if IsHTTP(server) { - return client.NewHTTP(server), nil - } - cmdParts := client.ParseCommandString(server) - c := client.NewStdio(cmdParts, opts...) - return c, nil + args = ParseCommandString(server) } } + var c *client.Client + var err error + if len(args) == 1 && IsHTTP(args[0]) { - return client.NewHTTP(args[0]), nil + c, err = client.NewSSEMCPClient(args[0]) + if err != nil { + return nil, err + } + err = c.Start(context.Background()) + } else { + c, err = client.NewStdioMCPClient(args[0], nil, args[1:]...) + } + + if err != nil { + return nil, err + } + + stdErr, ok := client.GetStderr(c) + if ok && ShowServerLogs { + go func() { + scanner := bufio.NewScanner(stdErr) + for scanner.Scan() { + fmt.Printf("[>] %s\n", scanner.Text()) + } + }() } - c := client.NewStdio(args, opts...) + done := make(chan error, 1) + + go func() { + _, err := c.Initialize(context.Background(), mcp.InitializeRequest{}) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return nil, fmt.Errorf("init error: %w", err) + } + case <-time.After(10 * time.Second): + return nil, fmt.Errorf("initialization timed out") + } return c, nil } @@ -80,7 +118,7 @@ func ProcessFlags(args []string) []string { // FormatAndPrintResponse formats and prints an MCP response in the format specified by // FormatOption. -func FormatAndPrintResponse(cmd *cobra.Command, resp map[string]any, err error) error { +func FormatAndPrintResponse(cmd *cobra.Command, resp any, err error) error { if err != nil { return fmt.Errorf("error: %w", err) } @@ -100,3 +138,36 @@ func IsValidFormat(format string) bool { format == "pretty" || format == "p" || format == "table" || format == "t" } + +// ParseCommandString splits a command string into separate arguments, +// respecting spaces as argument separators. +// Note: This is a simple implementation that doesn't handle quotes or escapes. +func ParseCommandString(cmdStr string) []string { + if cmdStr == "" { + return nil + } + + return strings.Fields(cmdStr) +} + +// ConvertJSONToSlice converts a JSON serialized object to a slice of any type. +func ConvertJSONToSlice(jsonData any) []any { + if jsonData == nil { + return nil + } + var toolsSlice []any + data, _ := json.Marshal(jsonData) + _ = json.Unmarshal(data, &toolsSlice) + return toolsSlice +} + +// ConvertJSONToMap converts a JSON serialized object to a map of strings to any type. +func ConvertJSONToMap(jsonData any) map[string]any { + if jsonData == nil { + return nil + } + var promptMap map[string]any + data, _ := json.Marshal(jsonData) + _ = json.Unmarshal(data, &promptMap) + return promptMap +} diff --git a/cmd/mcptools/commands/web.go b/cmd/mcptools/commands/web.go index 4406700..8e61e77 100644 --- a/cmd/mcptools/commands/web.go +++ b/cmd/mcptools/commands/web.go @@ -1,6 +1,7 @@ package commands import ( + "context" "encoding/json" "fmt" "net/http" @@ -8,7 +9,8 @@ import ( "strings" "sync" - "github.com/f/mcptools/pkg/client" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" ) @@ -47,18 +49,12 @@ func WebCmd() *cobra.Command { os.Exit(1) } - mcpClient, clientErr := CreateClientFunc(parsedArgs, client.CloseTransportAfterExecute(false)) + mcpClient, clientErr := CreateClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) } - _, listErr := mcpClient.ListTools() - if listErr != nil { - fmt.Fprintf(os.Stderr, "Error connecting to MCP server: %v\n", listErr) - os.Exit(1) - } - fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > Starting MCP Tools Web Interface (%s)\n", Version) fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > Connected to Server: %s\n", strings.Join(parsedArgs, " ")) fmt.Fprintf(thisCmd.OutOrStdout(), "mcp > Web server running at http://localhost:%s\n", port) @@ -116,7 +112,7 @@ func handleIndex() http.HandlerFunc { .hidden { display: none; } - + /* Only preserve critical styles that can't be easily done with Tailwind */ #raw-output-container { white-space: pre; @@ -126,44 +122,44 @@ func handleIndex() http.HandlerFunc { - +

Select an item from the sidebar

- + - +
Formatted
Raw JSON
- +
@@ -184,13 +180,13 @@ func handleIndex() http.HandlerFunc { toolsList.appendChild(li); }); } - + // Ensure formatted tab is visible by default document.getElementById('formatted-output-container').classList.remove('hidden'); document.getElementById('raw-output-container').classList.add('hidden'); }) .catch(err => console.error('Error fetching tools:', err)); - + // Fetch and display resources fetch('/api/resources') .then(response => response.json()) @@ -207,7 +203,7 @@ func handleIndex() http.HandlerFunc { } }) .catch(err => console.error('Error fetching resources:', err)); - + // Fetch and display prompts fetch('/api/prompts') .then(response => response.json()) @@ -224,73 +220,73 @@ func handleIndex() http.HandlerFunc { } }) .catch(err => console.error('Error fetching prompts:', err)); - + // Tab switching functionality document.getElementById('form-tab').addEventListener('click', () => { // First update the JSON to match any form changes updateJSONFromForm(); - + // Then switch to form view document.getElementById('form-tab').classList.add('active'); document.getElementById('form-tab').classList.remove('bg-gray-50'); document.getElementById('form-tab').classList.add('bg-white', 'text-blue-600'); - + document.getElementById('json-tab').classList.remove('active'); document.getElementById('json-tab').classList.remove('bg-white', 'text-blue-600'); document.getElementById('json-tab').classList.add('bg-gray-50', 'text-gray-500'); - + document.getElementById('form-container').classList.remove('hidden'); document.getElementById('json-editor-container').classList.add('hidden'); }); - + document.getElementById('json-tab').addEventListener('click', () => { // First update the form to match any JSON changes updateFormFromJSON(); - + // Then switch to JSON view document.getElementById('json-tab').classList.add('active'); document.getElementById('json-tab').classList.remove('bg-gray-50'); document.getElementById('json-tab').classList.add('bg-white', 'text-blue-600'); - + document.getElementById('form-tab').classList.remove('active'); document.getElementById('form-tab').classList.remove('bg-white', 'text-blue-600'); document.getElementById('form-tab').classList.add('bg-gray-50', 'text-gray-500'); - + document.getElementById('json-editor-container').classList.remove('hidden'); document.getElementById('form-container').classList.add('hidden'); }); - + document.getElementById('formatted-tab').addEventListener('click', () => { document.getElementById('formatted-tab').classList.add('active'); document.getElementById('formatted-tab').classList.remove('bg-gray-50'); document.getElementById('formatted-tab').classList.add('bg-white', 'text-blue-600'); - + document.getElementById('raw-tab').classList.remove('active'); document.getElementById('raw-tab').classList.remove('bg-white', 'text-blue-600'); document.getElementById('raw-tab').classList.add('bg-gray-50', 'text-gray-500'); - + document.getElementById('formatted-output-container').classList.remove('hidden'); document.getElementById('raw-output-container').classList.add('hidden'); }); - + document.getElementById('raw-tab').addEventListener('click', () => { document.getElementById('raw-tab').classList.add('active'); document.getElementById('raw-tab').classList.remove('bg-gray-50'); document.getElementById('raw-tab').classList.add('bg-white', 'text-blue-600'); - + document.getElementById('formatted-tab').classList.remove('active'); document.getElementById('formatted-tab').classList.remove('bg-white', 'text-blue-600'); document.getElementById('formatted-tab').classList.add('bg-gray-50', 'text-gray-500'); - + document.getElementById('raw-output-container').classList.remove('hidden'); document.getElementById('formatted-output-container').classList.add('hidden'); }); - + // Add live update to JSON editor with debounce let jsonUpdateTimeout = null; document.getElementById('params-area').addEventListener('input', () => { clearTimeout(jsonUpdateTimeout); - + // Use debounce to avoid excessive updates during typing jsonUpdateTimeout = setTimeout(() => { try { @@ -304,15 +300,15 @@ func handleIndex() http.HandlerFunc { } }, 500); // Wait 500ms after typing stops }); - + // Current tool being edited let currentTool = null; - + // Show tool details function showTool(tool) { currentTool = tool; document.getElementById('main-title').textContent = tool.name; - + // Set and show description const descriptionElement = document.getElementById('tool-description'); if (tool.description) { @@ -321,12 +317,12 @@ func handleIndex() http.HandlerFunc { } else { descriptionElement.classList.add('hidden'); } - + document.getElementById('tool-panel').classList.remove('hidden'); - + // Create form based on schema createFormFromSchema(tool); - + // Set default JSON parameters let defaultParams = {}; if (tool.parameters && tool.parameters.properties) { @@ -339,15 +335,15 @@ func handleIndex() http.HandlerFunc { }); } document.getElementById('params-area').value = JSON.stringify(defaultParams, null, 2); - + // Display initial information about the tool displayFormattedOutput({ tool: tool }); document.getElementById('raw-output-container').textContent = JSON.stringify(tool, null, 2); - + // Set up execute button document.getElementById('execute-btn').onclick = () => { let params = {}; - + // Check if we're using the form or JSON editor if (document.getElementById('form-container').classList.contains('hidden')) { // Using JSON editor @@ -362,23 +358,23 @@ func handleIndex() http.HandlerFunc { params = collectFormValues(tool); document.getElementById('params-area').value = JSON.stringify(params, null, 2); } - + callTool(tool.name, params); }; } - + // Update JSON editor with values from form function updateJSONFromForm() { if (!currentTool) return; - + const params = collectFormValues(currentTool); document.getElementById('params-area').value = JSON.stringify(params, null, 2); } - + // Update form with values from JSON editor function updateFormFromJSON() { if (!currentTool) return; - + try { const params = JSON.parse(document.getElementById('params-area').value); populateFormFromJSON(params); @@ -386,11 +382,11 @@ func handleIndex() http.HandlerFunc { alert('Error parsing JSON: ' + e.message); } } - + // Populate form fields from JSON data function populateFormFromJSON(jsonData) { if (!currentTool || !jsonData) return; - + // Get schema let schema = null; if (currentTool.parameters && currentTool.parameters.properties) { @@ -398,31 +394,31 @@ func handleIndex() http.HandlerFunc { } else if (currentTool.inputSchema && currentTool.inputSchema.properties) { schema = currentTool.inputSchema; } - + if (!schema) return; - + const properties = schema.properties; - + for (const propName in properties) { const prop = properties[propName]; const value = jsonData[propName]; - + if (value === undefined) continue; - + // Handle array of objects separately if (prop.type === 'array' && prop.items && prop.items.type === 'object' && prop.items.properties) { const arrayContainer = document.getElementById('array-container-' + propName); if (!arrayContainer) continue; - + // Clear existing items arrayContainer.innerHTML = ''; - + // Add new items based on the JSON data if (Array.isArray(value)) { value.forEach((itemData, index) => { // Add new item to the DOM addArrayItem(propName, prop.items); - + // Set values for each field for (const fieldName in prop.items.properties) { if (itemData[fieldName] !== undefined) { @@ -442,7 +438,7 @@ func handleIndex() http.HandlerFunc { // Handle regular inputs const input = document.getElementById('param-' + propName); if (!input) continue; - + if (prop.type === 'array' && Array.isArray(value)) { // For textarea arrays, join with newlines input.value = value.join('\n'); @@ -459,12 +455,12 @@ func handleIndex() http.HandlerFunc { } } } - + // Create form inputs based on schema function createFormFromSchema(tool) { const formContainer = document.getElementById('form-container'); formContainer.innerHTML = ''; - + // Check for schema in either parameters or inputSchema let schema = null; if (tool.parameters && tool.parameters.properties) { @@ -472,27 +468,27 @@ func handleIndex() http.HandlerFunc { } else if (tool.inputSchema && tool.inputSchema.properties) { schema = tool.inputSchema; } - + if (!schema) { formContainer.innerHTML = '

No parameters required for this tool.

'; return; } - + const properties = schema.properties; const required = schema.required || []; - + for (const propName in properties) { const prop = properties[propName]; const formGroup = document.createElement('div'); formGroup.className = 'mb-6'; formGroup.dataset.propName = propName; - + // Create label const label = document.createElement('label'); label.htmlFor = 'param-' + propName; label.textContent = propName; label.className = 'block text-sm font-medium text-gray-700 mb-1'; - + if (required.includes(propName)) { const requiredSpan = document.createElement('span'); requiredSpan.className = 'text-red-600 font-bold'; @@ -500,7 +496,7 @@ func handleIndex() http.HandlerFunc { label.appendChild(requiredSpan); } formGroup.appendChild(label); - + // Add description if available if (prop.description) { const description = document.createElement('div'); @@ -508,7 +504,7 @@ func handleIndex() http.HandlerFunc { description.textContent = prop.description; formGroup.appendChild(description); } - + // Handle different types of inputs if (prop.type === 'array') { // Create array container @@ -516,14 +512,14 @@ func handleIndex() http.HandlerFunc { arrayContainer.className = 'mb-4'; arrayContainer.id = 'array-container-' + propName; formGroup.appendChild(arrayContainer); - + // Check if this is an array of objects with schema defined const isObjectArray = prop.items && prop.items.type === 'object' && prop.items.properties; - + if (isObjectArray) { // Store the item schema for later use when adding new items arrayContainer.dataset.itemSchema = JSON.stringify(prop.items); - + // Add button for adding new items const addButton = document.createElement('button'); addButton.type = 'button'; @@ -533,15 +529,15 @@ func handleIndex() http.HandlerFunc { addArrayItem(propName, prop.items); updateJSONFromForm(); // Update JSON when adding items }; - + const arrayActions = document.createElement('div'); arrayActions.className = 'mt-2'; arrayActions.appendChild(addButton); formGroup.appendChild(arrayActions); - + // Make sure to append the formGroup to the DOM before adding items formContainer.appendChild(formGroup); - + // Add initial empty item addArrayItem(propName, prop.items); } else { @@ -552,12 +548,12 @@ func handleIndex() http.HandlerFunc { textarea.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 font-mono'; textarea.placeholder = 'Enter one item per line'; textarea.rows = 4; - + // Add event listener to update JSON when textarea changes textarea.addEventListener('input', () => updateJSONFromForm()); - + formGroup.appendChild(textarea); - + formContainer.appendChild(formGroup); } } else { @@ -567,19 +563,19 @@ func handleIndex() http.HandlerFunc { case 'boolean': input = document.createElement('select'); input.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500'; - + const trueOption = document.createElement('option'); trueOption.value = 'true'; trueOption.textContent = 'true'; - + const falseOption = document.createElement('option'); falseOption.value = 'false'; falseOption.textContent = 'false'; - + input.appendChild(trueOption); input.appendChild(falseOption); break; - + case 'number': case 'integer': input = document.createElement('input'); @@ -588,19 +584,19 @@ func handleIndex() http.HandlerFunc { if (prop.minimum !== undefined) input.min = prop.minimum; if (prop.maximum !== undefined) input.max = prop.maximum; break; - + case 'object': input = document.createElement('textarea'); input.placeholder = 'Enter JSON object'; input.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 font-mono'; input.rows = 4; break; - + default: // string or any other type if (prop.enum) { input = document.createElement('select'); input.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500'; - + prop.enum.forEach(option => { const optionEl = document.createElement('option'); optionEl.value = option; @@ -614,32 +610,32 @@ func handleIndex() http.HandlerFunc { if (prop.format === 'password') input.type = 'password'; } } - + input.id = 'param-' + propName; input.name = propName; - + // Add event listener to update JSON when input changes input.addEventListener('input', () => updateJSONFromForm()); if (input.tagName === 'SELECT') { input.addEventListener('change', () => updateJSONFromForm()); } - + formGroup.appendChild(input); formContainer.appendChild(formGroup); } } } - + // Add a new item to an array function addArrayItem(propName, itemSchema) { const container = document.getElementById('array-container-' + propName); const itemIndex = container.children.length; - + // Create item container const itemDiv = document.createElement('div'); itemDiv.className = 'relative p-4 mb-4 bg-gray-50 border border-gray-200 rounded-lg'; itemDiv.dataset.index = itemIndex; - + // Add remove button const removeButton = document.createElement('button'); removeButton.type = 'button'; @@ -653,20 +649,20 @@ func handleIndex() http.HandlerFunc { updateJSONFromForm(); }; itemDiv.appendChild(removeButton); - + // Create form fields based on the item schema if (itemSchema && itemSchema.properties) { for (const fieldName in itemSchema.properties) { const fieldProp = itemSchema.properties[fieldName]; const fieldGroup = document.createElement('div'); fieldGroup.className = 'mb-4'; - + // Label const label = document.createElement('label'); label.htmlFor = 'param-' + propName + '-' + itemIndex + '-' + fieldName; label.textContent = fieldName; label.className = 'block text-sm font-medium text-gray-700 mb-1'; - + if (itemSchema.required && itemSchema.required.includes(fieldName)) { const requiredSpan = document.createElement('span'); requiredSpan.className = 'text-red-600 font-bold'; @@ -674,7 +670,7 @@ func handleIndex() http.HandlerFunc { label.appendChild(requiredSpan); } fieldGroup.appendChild(label); - + // Description if (fieldProp.description) { const description = document.createElement('div'); @@ -682,26 +678,26 @@ func handleIndex() http.HandlerFunc { description.textContent = fieldProp.description; fieldGroup.appendChild(description); } - + // Input let input; switch (fieldProp.type) { case 'boolean': input = document.createElement('select'); input.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500'; - + const trueOption = document.createElement('option'); trueOption.value = 'true'; trueOption.textContent = 'true'; - + const falseOption = document.createElement('option'); falseOption.value = 'false'; falseOption.textContent = 'false'; - + input.appendChild(trueOption); input.appendChild(falseOption); break; - + case 'number': case 'integer': input = document.createElement('input'); @@ -710,12 +706,12 @@ func handleIndex() http.HandlerFunc { if (fieldProp.minimum !== undefined) input.min = fieldProp.minimum; if (fieldProp.maximum !== undefined) input.max = fieldProp.maximum; break; - + default: // string, object, or any other type if (fieldProp.enum) { input = document.createElement('select'); input.className = 'w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500'; - + fieldProp.enum.forEach(option => { const optionEl = document.createElement('option'); optionEl.value = option; @@ -729,33 +725,33 @@ func handleIndex() http.HandlerFunc { if (fieldProp.format === 'password') input.type = 'password'; } } - + input.id = 'param-' + propName + '-' + itemIndex + '-' + fieldName; input.name = propName + '-' + itemIndex + '-' + fieldName; input.dataset.field = fieldName; - + // Add event listener to update JSON when item field changes input.addEventListener('input', () => updateJSONFromForm()); if (input.tagName === 'SELECT') { input.addEventListener('change', () => updateJSONFromForm()); } - + fieldGroup.appendChild(input); itemDiv.appendChild(fieldGroup); } } - + container.appendChild(itemDiv); } - + // Update indices for array items after removal function updateArrayItemIndices(propName) { const container = document.getElementById('array-container-' + propName); const items = container.querySelectorAll('.array-item'); - + items.forEach((item, index) => { item.dataset.index = index; - + // Update all input IDs and names within this item const inputs = item.querySelectorAll('input, select, textarea'); inputs.forEach(input => { @@ -765,11 +761,11 @@ func handleIndex() http.HandlerFunc { }); }); } - + // Collect values from form function collectFormValues(tool) { const params = {}; - + // Check for schema in either parameters or inputSchema let schema = null; if (tool.parameters && tool.parameters.properties) { @@ -777,42 +773,42 @@ func handleIndex() http.HandlerFunc { } else if (tool.inputSchema && tool.inputSchema.properties) { schema = tool.inputSchema; } - + if (!schema) { return params; } - + const properties = schema.properties; - + for (const propName in properties) { const prop = properties[propName]; - + if (prop.type === 'array' && prop.items && prop.items.type === 'object' && prop.items.properties) { // Handle array of objects using the specialized UI const container = document.getElementById('array-container-' + propName); if (!container) continue; - + const items = container.querySelectorAll('.array-item'); const arrayValues = []; - + items.forEach(item => { const itemIndex = item.dataset.index; const itemValue = {}; - + // Collect all field values for this item for (const fieldName in prop.items.properties) { const input = document.getElementById('param-' + propName + '-' + itemIndex + '-' + fieldName); if (!input) continue; - + let value = input.value; - + // Convert types appropriately const fieldProp = prop.items.properties[fieldName]; switch (fieldProp.type) { case 'boolean': value = value === 'true'; break; - + case 'number': case 'integer': if (value !== '') { @@ -821,39 +817,39 @@ func handleIndex() http.HandlerFunc { continue; // Skip empty values } break; - + default: // For strings, just use as-is break; } - + // Only add non-empty values if (value !== '' && value !== undefined) { itemValue[fieldName] = value; } } - + // Only add non-empty objects if (Object.keys(itemValue).length > 0) { arrayValues.push(itemValue); } }); - + params[propName] = arrayValues; } else { // Handle regular inputs or simple arrays const input = document.getElementById('param-' + propName); - + if (!input) continue; - + let value = input.value; - + // Convert types appropriately switch (prop.type) { case 'boolean': value = value === 'true'; break; - + case 'number': case 'integer': if (value !== '') { @@ -862,7 +858,7 @@ func handleIndex() http.HandlerFunc { continue; // Skip empty values } break; - + case 'array': if (value) { // Split by new lines and filter empty lines @@ -873,7 +869,7 @@ func handleIndex() http.HandlerFunc { value = []; } break; - + case 'object': if (value) { try { @@ -886,22 +882,22 @@ func handleIndex() http.HandlerFunc { value = {}; } break; - + default: // For strings, just use as-is break; } - + // Only add non-empty values if (value !== '' && value !== undefined) { params[propName] = value; } } } - + return params; } - + // Call a tool with parameters function callTool(name, params) { fetch('/api/call', { @@ -924,18 +920,18 @@ func handleIndex() http.HandlerFunc { }) .catch(err => { document.getElementById('raw-output-container').textContent = 'Error calling tool: ' + err.message; - document.getElementById('formatted-output-container').innerHTML = - '

Error

' + + document.getElementById('formatted-output-container').innerHTML = + '

Error

' + err.message + '
'; }); } - + // Call a resource function callResource(uri) { document.getElementById('main-title').textContent = 'Resource: ' + uri; document.getElementById('tool-description').classList.add('hidden'); document.getElementById('tool-panel').classList.add('hidden'); - + fetch('/api/call', { method: 'POST', headers: { @@ -955,18 +951,18 @@ func handleIndex() http.HandlerFunc { }) .catch(err => { document.getElementById('raw-output-container').textContent = 'Error reading resource: ' + err.message; - document.getElementById('formatted-output-container').innerHTML = - '

Error

' + + document.getElementById('formatted-output-container').innerHTML = + '

Error

' + err.message + '
'; }); } - + // Call a prompt function callPrompt(name) { document.getElementById('main-title').textContent = 'Prompt: ' + name; document.getElementById('tool-description').classList.add('hidden'); document.getElementById('tool-panel').classList.add('hidden'); - + fetch('/api/call', { method: 'POST', headers: { @@ -986,59 +982,59 @@ func handleIndex() http.HandlerFunc { }) .catch(err => { document.getElementById('raw-output-container').textContent = 'Error getting prompt: ' + err.message; - document.getElementById('formatted-output-container').innerHTML = - '

Error

' + + document.getElementById('formatted-output-container').innerHTML = + '

Error

' + err.message + '
'; }); } - + // Display formatted output function displayFormattedOutput(data) { const container = document.getElementById('formatted-output-container'); container.innerHTML = ''; - + if (data.error) { const errorDiv = document.createElement('div'); errorDiv.className = 'result-object'; - + const errorTitle = document.createElement('h3'); errorTitle.textContent = 'Error'; errorDiv.appendChild(errorTitle); - + const errorText = document.createElement('div'); errorText.className = 'result-property'; errorText.textContent = data.error; errorDiv.appendChild(errorText); - + container.appendChild(errorDiv); return; } - + renderObject(data, container); } - + // Recursively render object function renderObject(obj, container, level = 0) { if (!obj || typeof obj !== 'object') return; - + for (const key in obj) { if (!obj.hasOwnProperty(key)) continue; - + const value = obj[key]; - + if (value && typeof value === 'object' && !Array.isArray(value)) { // This is an object const objectDiv = document.createElement('div'); objectDiv.className = 'pl-4 border-l-2 border-blue-400 mb-4'; objectDiv.style.marginLeft = (level * 16) + 'px'; - + const objectTitle = document.createElement('h3'); objectTitle.className = 'text-lg font-semibold text-blue-600 mb-2'; objectTitle.textContent = key; objectDiv.appendChild(objectTitle); - + container.appendChild(objectDiv); - + // Recursively render properties renderObject(value, objectDiv, level + 1); } else if (key === "content" && Array.isArray(value)) { @@ -1046,14 +1042,14 @@ func handleIndex() http.HandlerFunc { const contentDiv = document.createElement('div'); contentDiv.className = 'pl-4 border-l-2 border-blue-400 mb-4'; contentDiv.style.marginLeft = (level * 16) + 'px'; - + const contentTitle = document.createElement('h3'); contentTitle.className = 'text-lg font-semibold text-blue-600 mb-2'; contentTitle.textContent = key; contentDiv.appendChild(contentTitle); - + container.appendChild(contentDiv); - + // Process each content item value.forEach((item, index) => { if (typeof item === 'object') { @@ -1061,12 +1057,12 @@ func handleIndex() http.HandlerFunc { const itemDiv = document.createElement('div'); itemDiv.className = 'pl-4 border-l-2 border-gray-300 mb-2 pb-2'; itemDiv.style.marginLeft = '16px'; - + const itemTitle = document.createElement('h3'); itemTitle.className = 'text-md font-semibold text-gray-700 mb-1'; itemTitle.textContent = 'Item ' + (index + 1); itemDiv.appendChild(itemTitle); - + contentDiv.appendChild(itemDiv); renderObject(item, itemDiv, level + 2); } else if (typeof item === 'string') { @@ -1077,12 +1073,12 @@ func handleIndex() http.HandlerFunc { const itemDiv = document.createElement('div'); itemDiv.className = 'pl-4 border-l-2 border-gray-300 mb-2 pb-2'; itemDiv.style.marginLeft = '16px'; - + const itemTitle = document.createElement('h3'); itemTitle.className = 'text-md font-semibold text-gray-700 mb-1'; itemTitle.textContent = 'Item ' + (index + 1); itemDiv.appendChild(itemTitle); - + contentDiv.appendChild(itemDiv); renderObject(parsedItem, itemDiv, level + 2); } else { @@ -1104,21 +1100,21 @@ func handleIndex() http.HandlerFunc { } } } - + // Helper function to render primitive values function renderPrimitiveValue(container, key, value, level) { const propertyDiv = document.createElement('div'); propertyDiv.className = 'py-1 flex flex-wrap'; propertyDiv.style.marginLeft = (level * 16) + 'px'; - + const nameSpan = document.createElement('span'); nameSpan.className = 'text-gray-600 mr-2 font-medium'; nameSpan.textContent = key + ': '; propertyDiv.appendChild(nameSpan); - + const valueSpan = document.createElement('span'); valueSpan.className = 'font-mono'; - + if (value === null) { valueSpan.classList.add('text-gray-500'); valueSpan.classList.add('italic'); @@ -1127,7 +1123,7 @@ func handleIndex() http.HandlerFunc { valueSpan.textContent = JSON.stringify(value); } else { const type = typeof value; - + if (type === 'string') { valueSpan.classList.add('text-green-600'); } else if (type === 'number') { @@ -1135,14 +1131,14 @@ func handleIndex() http.HandlerFunc { } else if (type === 'boolean') { valueSpan.classList.add('text-red-600'); } - + // Check if string might be parseable JSON if (type === 'string' && value.trim().startsWith('{') && value.trim().endsWith('}')) { try { // Try to parse and pretty print const parsed = JSON.parse(value); valueSpan.textContent = JSON.stringify(parsed, null, 2); - + // Add a special class for JSON strings valueSpan.className = 'font-mono p-3 mt-2 mb-2 block bg-gray-50 border border-gray-200 rounded-md overflow-x-auto whitespace-pre'; } catch (e) { @@ -1153,7 +1149,7 @@ func handleIndex() http.HandlerFunc { valueSpan.textContent = type === 'string' ? '"' + value + '"' : String(value); } } - + propertyDiv.appendChild(valueSpan); container.appendChild(propertyDiv); } @@ -1172,7 +1168,7 @@ func handleTools(cache *MCPClientCache) http.HandlerFunc { //nolint:revive // Parameter r is required by http.HandlerFunc signature return func(w http.ResponseWriter, r *http.Request) { cache.mutex.Lock() - resp, err := cache.client.ListTools() + resp, err := cache.client.ListTools(context.Background(), mcp.ListToolsRequest{}) cache.mutex.Unlock() w.Header().Set("Content-Type", "application/json") @@ -1197,7 +1193,7 @@ func handleResources(cache *MCPClientCache) http.HandlerFunc { //nolint:revive // Parameter r is required by http.HandlerFunc signature return func(w http.ResponseWriter, r *http.Request) { cache.mutex.Lock() - resp, err := cache.client.ListResources() + resp, err := cache.client.ListResources(context.Background(), mcp.ListResourcesRequest{}) cache.mutex.Unlock() w.Header().Set("Content-Type", "application/json") @@ -1222,7 +1218,7 @@ func handlePrompts(cache *MCPClientCache) http.HandlerFunc { //nolint:revive // Parameter r is required by http.HandlerFunc signature return func(w http.ResponseWriter, r *http.Request) { cache.mutex.Lock() - resp, err := cache.client.ListPrompts() + resp, err := cache.client.ListPrompts(context.Background(), mcp.ListPromptsRequest{}) cache.mutex.Unlock() w.Header().Set("Content-Type", "application/json") @@ -1274,11 +1270,24 @@ func handleCall(cache *MCPClientCache) http.HandlerFunc { switch requestData.Type { case "tool": - resp, callErr = cache.client.CallTool(requestData.Name, requestData.Params) + var toolResponse *mcp.CallToolResult + request := mcp.CallToolRequest{} + request.Params.Name = requestData.Name + request.Params.Arguments = requestData.Params + toolResponse, callErr = cache.client.CallTool(context.Background(), request) + resp = ConvertJSONToMap(toolResponse) case "resource": - resp, callErr = cache.client.ReadResource(requestData.Name) + var resourceResponse *mcp.ReadResourceResult + request := mcp.ReadResourceRequest{} + request.Params.URI = requestData.Name + resourceResponse, callErr = cache.client.ReadResource(context.Background(), request) + resp = ConvertJSONToMap(resourceResponse) case "prompt": - resp, callErr = cache.client.GetPrompt(requestData.Name) + var promptResponse *mcp.GetPromptResult + request := mcp.GetPromptRequest{} + request.Params.Name = requestData.Name + promptResponse, callErr = cache.client.GetPrompt(context.Background(), request) + resp = ConvertJSONToMap(promptResponse) default: w.WriteHeader(http.StatusBadRequest) //nolint:errcheck,gosec // No need to handle error from Encode in this context diff --git a/cmd/mcptools/main_test.go b/cmd/mcptools/main_test.go index 75b2180..1be87c8 100644 --- a/cmd/mcptools/main_test.go +++ b/cmd/mcptools/main_test.go @@ -12,7 +12,6 @@ import ( "testing" "github.com/f/mcptools/cmd/mcptools/commands" - "github.com/f/mcptools/pkg/transport" ) const entityTypeValue = "tool" @@ -97,7 +96,7 @@ func (t *MockTransport) Execute(method string, params interface{}) (map[string]i } type Shell struct { - Transport transport.Transport + Transport *MockTransport Reader io.Reader Writer io.Writer Format string diff --git a/go.mod b/go.mod index 41e3f75..96905dc 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/f/mcptools go 1.24.1 require ( + github.com/mark3labs/mcp-go v0.24.1 github.com/peterh/liner v1.2.2 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 @@ -15,6 +16,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect @@ -27,6 +29,7 @@ require ( github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/sys v0.31.0 // indirect diff --git a/go.sum b/go.sum index 675d4dd..0426ba3 100644 --- a/go.sum +++ b/go.sum @@ -10,12 +10,18 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.23.1 h1:RzTzZ5kJ+HxwnutKA4rll8N/pKV6Wh5dhCmiJUu5S9I= +github.com/mark3labs/mcp-go v0.23.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.24.1 h1:YV+5X/+W4oBdERLWgiA1uR7AIvenlKJaa5V4hqufI7E= +github.com/mark3labs/mcp-go v0.24.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= @@ -51,6 +57,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/pkg/client/client.go b/pkg/client/client.go deleted file mode 100644 index 48efae3..0000000 --- a/pkg/client/client.go +++ /dev/null @@ -1,128 +0,0 @@ -/* -Package client implements mcp client functionality. -*/ -package client - -import ( - "fmt" - "os" - "strings" - - "github.com/f/mcptools/pkg/transport" -) - -// Client provides an interface to interact with MCP servers. -// It abstracts away the transport mechanism so callers don't need -// to worry about the details of HTTP, stdio, etc. -type Client struct { - transport transport.Transport -} - -// Option provides a way for passing options to the Client to change its -// configuration. -type Option func(*Client) - -// CloseTransportAfterExecute allows keeping a transport alive if supported by -// the transport. -func CloseTransportAfterExecute(closeTransport bool) Option { - return func(c *Client) { - t, ok := c.transport.(interface{ SetCloseAfterExecute(bool) }) - if ok { - t.SetCloseAfterExecute(closeTransport) - } - } -} - -// SetShowServerLogs sets whether to show server logs. -func SetShowServerLogs(showLogs bool) Option { - return func(c *Client) { - t, ok := c.transport.(interface{ SetShowServerLogs(bool) }) - if ok { - t.SetShowServerLogs(showLogs) - } - } -} - -// NewWithTransport creates a new MCP client using the provided transport. -// This allows callers to provide a custom transport implementation. -func NewWithTransport(t transport.Transport) *Client { - return &Client{ - transport: t, - } -} - -// NewStdio creates a new MCP client that communicates with a command -// via stdin/stdout using JSON-RPC. -func NewStdio(command []string, opts ...Option) *Client { - c := &Client{ - transport: transport.NewStdio(command), - } - for _, opt := range opts { - opt(c) - } - return c -} - -// NewHTTP creates a MCP client that communicates with a server via HTTP using JSON-RPC. -func NewHTTP(address string) *Client { - transport, err := transport.NewHTTP(address) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating HTTP transport: %s\n", err) - os.Exit(1) - } - - return &Client{ - transport: transport, - } -} - -// ListTools retrieves the list of available tools from the MCP server. -func (c *Client) ListTools() (map[string]any, error) { - return c.transport.Execute("tools/list", nil) -} - -// ListResources retrieves the list of available resources from the MCP server. -func (c *Client) ListResources() (map[string]any, error) { - return c.transport.Execute("resources/list", nil) -} - -// ListPrompts retrieves the list of available prompts from the MCP server. -func (c *Client) ListPrompts() (map[string]any, error) { - return c.transport.Execute("prompts/list", nil) -} - -// CallTool calls a specific tool on the MCP server with the given arguments. -func (c *Client) CallTool(toolName string, args map[string]any) (map[string]any, error) { - params := map[string]any{ - "name": toolName, - "arguments": args, - } - return c.transport.Execute("tools/call", params) -} - -// GetPrompt retrieves a specific prompt from the MCP server. -func (c *Client) GetPrompt(promptName string) (map[string]any, error) { - params := map[string]any{ - "name": promptName, - } - return c.transport.Execute("prompts/get", params) -} - -// ReadResource reads the content of a specific resource from the MCP server. -func (c *Client) ReadResource(uri string) (map[string]any, error) { - params := map[string]any{ - "uri": uri, - } - return c.transport.Execute("resources/read", params) -} - -// ParseCommandString splits a command string into separate arguments, -// respecting spaces as argument separators. -// Note: This is a simple implementation that doesn't handle quotes or escapes. -func ParseCommandString(cmdStr string) []string { - if cmdStr == "" { - return nil - } - - return strings.Fields(cmdStr) -} diff --git a/pkg/transport/http.go b/pkg/transport/http.go deleted file mode 100644 index f55f9a5..0000000 --- a/pkg/transport/http.go +++ /dev/null @@ -1,187 +0,0 @@ -package transport - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "strings" - "time" -) - -// HTTP implements the Transport interface by communicating with a MCP server over HTTP using JSON-RPC. -type HTTP struct { - eventCh chan string - address string - debug bool - nextID int -} - -// NewHTTP creates a new Http transport that will execute the given command. -// It communicates with the command using JSON-RPC over HTTP. -// Currently Http transport is implements MCP's Final draft version 2024-11-05, -// https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse -func NewHTTP(address string) (*HTTP, error) { - debug := os.Getenv("MCP_DEBUG") == "1" - - _, uriErr := url.ParseRequestURI(address) - if uriErr != nil { - return nil, fmt.Errorf("invalid address: %w", uriErr) - } - - resp, err := http.Get(address + "/sse") - if err != nil { - return nil, fmt.Errorf("error sending request: %w", err) - } - - eventCh := make(chan string, 1) - - go func() { - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - fmt.Fprintf(os.Stderr, "Failed to close response body: %v\n", closeErr) - } - }() - - reader := bufio.NewReader(resp.Body) - for { - line, lineErr := reader.ReadString('\n') - if lineErr != nil { - fmt.Fprintf(os.Stderr, "SSE read error: %v\n", lineErr) - return - } - line = strings.TrimSpace(line) - if debug { - fmt.Fprintf(os.Stderr, "DEBUG: Received SSE: %s\n", line) - } - if strings.HasPrefix(line, "data:") { - data := strings.TrimSpace(line[5:]) - select { - case eventCh <- data: - default: - } - } - } - }() - - // First event we receive from SSE is the message address. We will use this endpoint to keep - // a session alive. - var messageAddress string - select { - case msg := <-eventCh: - messageAddress = msg - case <-time.After(10 * time.Second): - return nil, fmt.Errorf("timeout waiting for SSE response") - } - - client := &HTTP{ - // Use the SSE message address as the base address for the HTTP transport - address: address + "/sse" + messageAddress, - nextID: 1, - debug: debug, - eventCh: eventCh, - } - - // Send initialize request - _, err = client.Execute("initialize", map[string]any{ - "clientInfo": map[string]any{ - "name": "mcp-client", - "version": "0.1.0", - }, - "capabilities": map[string]any{}, - "protocolVersion": "2024-11-05", - }) - if err != nil { - return nil, fmt.Errorf("error sending initialize request: %w", err) - } - - // Send intialized notification - if err := client.send("notifications/initialized", nil); err != nil { - return nil, fmt.Errorf("error sending initialized notification: %w", err) - } - - return client, nil -} - -// Execute implements the Transport via JSON-RPC over HTTP. -func (t *HTTP) Execute(method string, params any) (map[string]any, error) { - if err := t.send(method, params); err != nil { - return nil, err - } - - // After sending the request, we listen the SSE channel for the response - var response Response - select { - case msg := <-t.eventCh: - if unmarshalErr := json.Unmarshal([]byte(msg), &response); unmarshalErr != nil { - return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, msg) - } - case <-time.After(10 * time.Second): - return nil, fmt.Errorf("timeout waiting for SSE response") - } - - if response.Error != nil { - return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n") - } - - return response.Result, nil -} - -func (t *HTTP) send(method string, params any) error { - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Connecting to server: %s\n", t.address) - } - - request := Request{ - JSONRPC: "2.0", - Method: method, - ID: t.nextID, - Params: params, - } - t.nextID++ - - requestJSON, err := json.Marshal(request) - if err != nil { - return fmt.Errorf("error marshaling request: %w", err) - } - - requestJSON = append(requestJSON, '\n') - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Sending request: %s\n", string(requestJSON)) - } - - resp, err := http.Post(t.address, "application/json", bytes.NewBuffer(requestJSON)) - if err != nil { - return fmt.Errorf("error sending request: %w", err) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Sent request to server\n") - } - - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - fmt.Fprintf(os.Stderr, "Failed to close response body: %v\n", closeErr) - } - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("error reading response: %w", err) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Read from server: %s\n", string(body)) - } - - return nil -} diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go deleted file mode 100644 index fa451ac..0000000 --- a/pkg/transport/stdio.go +++ /dev/null @@ -1,360 +0,0 @@ -package transport - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "io" - "os" - "os/exec" - "strings" - "time" -) - -// Stdio implements the Transport interface by executing a command -// and communicating with it via stdin/stdout using JSON-RPC. -type Stdio struct { - process *stdioProcess - command []string - nextID int - debug bool - showServerLogs bool -} - -// stdioProcess reflects the state of a running command. -type stdioProcess struct { - stdin io.WriteCloser - stdout io.ReadCloser - cmd *exec.Cmd - stderrBuf *bytes.Buffer - isInitializeSent bool -} - -// NewStdio creates a new Stdio transport that will execute the given command. -// It communicates with the command using JSON-RPC over stdin/stdout. -func NewStdio(command []string) *Stdio { - debug := os.Getenv("MCP_DEBUG") == "1" - return &Stdio{ - command: command, - nextID: 1, - debug: debug, - } -} - -// SetCloseAfterExecute toggles whether the underlying process should be closed -// or kept alive after each call to Execute. -func (t *Stdio) SetCloseAfterExecute(v bool) { - if v { - t.process = nil - } else { - t.process = &stdioProcess{} - } -} - -// SetShowServerLogs toggles whether to print server logs. -func (t *Stdio) SetShowServerLogs(v bool) { - t.showServerLogs = v -} - -// Execute implements the Transport interface by spawning a subprocess -// and communicating with it via JSON-RPC over stdin/stdout. -func (t *Stdio) Execute(method string, params any) (map[string]any, error) { - process := t.process - if process == nil { - process = &stdioProcess{} - } - - if process.cmd == nil { - var err error - process.stdin, process.stdout, process.cmd, process.stderrBuf, err = t.setupCommand() - if err != nil { - return nil, err - } - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n") - } - - if !process.isInitializeSent { - if initErr := t.initialize(process.stdin, process.stdout); initErr != nil { - t.printStderr(process) - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr) - } - return nil, initErr - } - t.printStderr(process) - process.isInitializeSent = true - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Initialization successful, sending method request\n") - } - - request := Request{ - JSONRPC: "2.0", - Method: method, - ID: t.nextID, - Params: params, - } - t.nextID++ - - if sendErr := t.sendRequest(process.stdin, request); sendErr != nil { - return nil, sendErr - } - - response, err := t.readResponse(process.stdout) - t.printStderr(process) - if err != nil { - return nil, err - } - err = t.closeProcess(process) - if err != nil { - return nil, err - } - - return response.Result, nil -} - -// printStderr prints and clears any accumulated stderr output. -func (t *Stdio) printStderr(process *stdioProcess) { - if !t.showServerLogs { - return - } - if process.stderrBuf.Len() > 0 { - for _, line := range strings.SplitAfter(process.stderrBuf.String(), "\n") { - line = strings.TrimSuffix(line, "\n") - if line != "" { - fmt.Fprintf(os.Stderr, "[>] %s\n", line) - } - } - process.stderrBuf.Reset() // Clear the buffer after reading - } -} - -// closeProcess waits for the command to finish, returning any error. -func (t *Stdio) closeProcess(process *stdioProcess) error { - if t.process != nil { - return nil - } - - _ = process.stdin.Close() - - // Wait for the command to finish with a timeout to prevent zombie processes - done := make(chan error, 1) - go func() { - done <- process.cmd.Wait() - }() - - select { - case waitErr := <-done: - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr) - } - - if waitErr != nil && process.stderrBuf.Len() > 0 { - return fmt.Errorf("command error: %w", waitErr) - } - case <-time.After(1 * time.Second): - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Command timed out after 1 seconds\n") - } - // Kill the process if it times out - _ = process.cmd.Process.Kill() - } - - return nil -} - -// setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error. -func (t *Stdio) setupCommand() (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, stderrBuf *bytes.Buffer, err error) { - if len(t.command) == 0 { - return nil, nil, nil, nil, fmt.Errorf("no command specified for stdio transport") - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Executing command: %v\n", t.command) - } - - cmd = exec.Command(t.command[0], t.command[1:]...) // #nosec G204 - - stdin, err = cmd.StdinPipe() - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("error getting stdin pipe: %w", err) - } - - stdout, err = cmd.StdoutPipe() - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("error getting stdout pipe: %w", err) - } - - stderrBuf = &bytes.Buffer{} - cmd.Stderr = stderrBuf - - if err = cmd.Start(); err != nil { - return nil, nil, nil, nil, fmt.Errorf("error starting command: %w", err) - } - - return stdin, stdout, cmd, stderrBuf, nil -} - -// initialize sends the initialization request and waits for response and then sends the initialized -// notification. -func (t *Stdio) initialize(stdin io.WriteCloser, stdout io.ReadCloser) error { - // Create initialization request with current ID - initRequestID := t.nextID - initRequest := Request{ - JSONRPC: "2.0", - Method: "initialize", - ID: initRequestID, - Params: map[string]any{ - "clientInfo": map[string]any{ - "name": "f/mcptools", - "version": "beta", - }, - "protocolVersion": protocolVersion, - "capabilities": map[string]any{}, - }, - } - t.nextID++ - - if err := t.sendRequest(stdin, initRequest); err != nil { - return fmt.Errorf("init request failed: %w", err) - } - - // readResponse now properly checks for matching response ID - _, err := t.readResponse(stdout) - if err != nil { - return fmt.Errorf("init response failed: %w", err) - } - - // Send initialized notification (notifications don't have IDs) - initNotification := Request{ - JSONRPC: "2.0", - Method: "notifications/initialized", - } - - if sendErr := t.sendRequest(stdin, initNotification); sendErr != nil { - return fmt.Errorf("init notification failed: %w", sendErr) - } - - return nil -} - -// sendRequest sends a JSON-RPC request and returns the marshaled request. -func (t *Stdio) sendRequest(stdin io.WriteCloser, request Request) error { - requestJSON, err := json.Marshal(request) - if err != nil { - return fmt.Errorf("error marshaling request: %w", err) - } - requestJSON = append(requestJSON, '\n') - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Preparing to send request: %s\n", string(requestJSON)) - } - - writer := bufio.NewWriter(stdin) - n, err := writer.Write(requestJSON) - if err != nil { - return fmt.Errorf("error writing bytes to stdin: %w", err) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Wrote %d bytes\n", n) - } - - if flushErr := writer.Flush(); flushErr != nil { - return fmt.Errorf("error flushing bytes to stdin: %w", flushErr) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Successfully flushed bytes\n") - } - - return nil -} - -// readResponse reads and parses a JSON-RPC response matching the given request ID. -func (t *Stdio) readResponse(stdout io.ReadCloser) (*Response, error) { - reader := bufio.NewReader(stdout) - - // Keep track of the expected response ID (the last request ID we sent) - expectedID := t.nextID - 1 - - for { - line, err := reader.ReadBytes('\n') - if err != nil { - return nil, fmt.Errorf("error reading from stdout: %w", err) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s", string(line)) - } - - if len(line) == 0 { - return nil, fmt.Errorf("no response from command") - } - - // First check if this is a notification (no ID field) - var msg map[string]interface{} - if err := json.Unmarshal(line, &msg); err != nil { - return nil, fmt.Errorf("error unmarshaling message: %w, response: %s", err, string(line)) - } - - // If it's a notification, display it and continue reading - if methodVal, hasMethod := msg["method"]; hasMethod && msg["id"] == nil { - method, ok := methodVal.(string) - if ok && method == "notifications/message" { - if paramsVal, hasParams := msg["params"].(map[string]interface{}); hasParams { - level, _ := paramsVal["level"].(string) - data, _ := paramsVal["data"].(string) - - // Format and print the notification based on level - switch level { - case "error": - fmt.Fprintf(os.Stderr, "\033[31m[ERROR] %s\033[0m\n", data) // Red - case "warning": - fmt.Fprintf(os.Stderr, "\033[33m[WARNING] %s\033[0m\n", data) // Yellow - case "alert": - fmt.Fprintf(os.Stderr, "\033[35m[ALERT] %s\033[0m\n", data) // Magenta - case "info": - fmt.Fprintf(os.Stderr, "\033[36m[INFO] %s\033[0m\n", data) // Cyan - default: - fmt.Fprintf(os.Stderr, "\033[37m[%s] %s\033[0m\n", level, data) // White for unknown levels - } - } - } else { - // For other notification types - fmt.Fprintf(os.Stderr, "[Notification] %s\n", string(line)) - } - continue - } - - // Parse as a proper response - var response Response - if unmarshalErr := json.Unmarshal(line, &response); unmarshalErr != nil { - return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, string(line)) - } - - // If this response has an ID field and it matches our expected ID, or if it has an error, return it - if response.ID == expectedID || response.Error != nil { - if response.Error != nil { - return nil, fmt.Errorf("RPC error %d: %s", response.Error.Code, response.Error.Message) - } - - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response with matching ID: %d\n", response.ID) - } - - return &response, nil - } - - // Otherwise, this is a response for a different request - if t.debug { - fmt.Fprintf(os.Stderr, "DEBUG: Received response for request ID %d, expecting %d. Continuing to read.\n", - response.ID, expectedID) - } - } -} diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go deleted file mode 100644 index b27b7a8..0000000 --- a/pkg/transport/transport.go +++ /dev/null @@ -1,48 +0,0 @@ -// Package transport contains implementatations for different transport options for MCP. -package transport - -import ( - "encoding/json" - "io" -) - -const ( - protocolVersion = "2024-11-05" -) - -// Transport defines the interface for communicating with MCP servers. -// Implementations should handle the specifics of communication protocols. -type Transport interface { - Execute(method string, params any) (map[string]any, error) -} - -// Request represents a JSON-RPC 2.0 request. -type Request struct { - Params any `json:"params,omitempty"` - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - ID int `json:"id,omitempty"` -} - -// Response represents a JSON-RPC 2.0 response. -type Response struct { - Result map[string]any `json:"result,omitempty"` - Error *Error `json:"error,omitempty"` - JSONRPC string `json:"jsonrpc"` - ID int `json:"id"` -} - -// Error represents a JSON-RPC 2.0 error. -type Error struct { - Message string `json:"message"` - Code int `json:"code"` -} - -// ParseResponse reads and parses a JSON-RPC response from a reader. -func ParseResponse(r io.Reader) (*Response, error) { - var response Response - if err := json.NewDecoder(r).Decode(&response); err != nil { - return nil, err - } - return &response, nil -}