Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion internal/codemode/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"go/parser"
"go/printer"
"go/token"
"os"
"os/exec"
Expand All @@ -16,6 +17,7 @@ import (
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/imports"
)

Expand All @@ -25,6 +27,9 @@ const mcpSDKVersion = "v1.1.0"
// gracePeriod is the time to wait after sending SIGINT before sending SIGKILL
const gracePeriod = 5 * time.Second

// mcpSDKImport is the import path for the MCP SDK package
const mcpSDKImport = "github.com/modelcontextprotocol/go-sdk/mcp"

// ExecutionResult represents the outcome of code execution
type ExecutionResult struct {
Output string // Combined stdout/stderr
Expand Down Expand Up @@ -231,10 +236,12 @@ func autoCorrectImports(ctx context.Context, dir, filename string) string {
return ""
}

// Ensure correct mcp import before goimports to prevent wrong package resolution
preprocessed := ensureMCPImport(orig)
origImports := extractImports(orig)

// Process the file using golang.org/x/tools/imports
newContent, err := imports.Process(filePath, orig, nil)
newContent, err := imports.Process(filePath, preprocessed, nil)
if err != nil {
// If processing fails (e.g. syntax errors), let the compiler catch it
return ""
Expand Down Expand Up @@ -318,6 +325,26 @@ func unmarshalContent(data []byte) ([]mcp.Content, error) {
return result, nil
}

// ensureMCPImport adds the MCP SDK import if not already present.
// This prevents goimports from resolving mcp to the wrong package.
// The import will be removed by goimports if not actually used.
func ensureMCPImport(src []byte) []byte {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if err != nil {
return src
}

astutil.AddImport(fset, f, mcpSDKImport)

var buf bytes.Buffer
if err := printer.Fprint(&buf, fset, f); err != nil {
return src
}

return buf.Bytes()
}

// formatImportChanges generates a message describing which imports were added/removed.
func formatImportChanges(filename string, origImports, newImports map[string]bool) string {
var added, removed []string
Expand Down
123 changes: 123 additions & 0 deletions internal/codemode/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,3 +927,126 @@ func Run(ctx context.Context) ([]mcp.Content, error) {
t.Errorf("Output = %q, want to contain 'text only output'", result.Output)
}
}

func TestEnsureMCPImport(t *testing.T) {
tests := []struct {
name string
input string
wantImport bool
}{
{
name: "adds import when no imports exist",
input: `package main

func Run() error {
return nil
}
`,
wantImport: true,
},
{
name: "adds import to existing import block",
input: `package main

import "fmt"

func Run() error {
fmt.Println("hello")
return nil
}
`,
wantImport: true,
},
{
name: "import already present",
input: `package main

import (
"context"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

func Run(ctx context.Context) ([]mcp.Content, error) {
return nil, nil
}
`,
wantImport: true,
},
{
name: "adds to multi-import block",
input: `package main

import (
"context"
"os"
)

func Run(ctx context.Context) ([]mcp.Content, error) {
data, _ := os.ReadFile("image.png")
return []mcp.Content{
&mcp.ImageContent{Data: data, MIMEType: "image/png"},
}, nil
}
`,
wantImport: true,
},
{
name: "syntax error - returns unchanged",
input: `package main

import (
"context"

func Run( { // syntax error
`,
wantImport: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ensureMCPImport([]byte(tt.input))
hasImport := strings.Contains(string(result), "github.com/modelcontextprotocol/go-sdk/mcp")
if hasImport != tt.wantImport {
t.Errorf("hasImport = %v, want %v\nresult:\n%s", hasImport, tt.wantImport, string(result))
}
})
}
}

func TestEnsureMCPImport_Integration(t *testing.T) {
code := `package main

import (
"context"
"os"
)

func Run(ctx context.Context) ([]mcp.Content, error) {
data, _ := os.ReadFile("test.png")
return []mcp.Content{
&mcp.ImageContent{Data: data, MIMEType: "image/png"},
}, nil
}
`

result := ensureMCPImport([]byte(code))
if !strings.Contains(string(result), "github.com/modelcontextprotocol/go-sdk/mcp") {
t.Fatal("expected mcp import to be added")
}

ctx := context.Background()
execResult, err := ExecuteCode(ctx, nil, string(result), 30)
if err != nil {
if strings.Contains(execResult.Output, "undefined: mcp") {
t.Errorf("mcp import was not properly added:\n%s", execResult.Output)
}
// Other errors are acceptable (e.g., file not found at runtime)
}
// Verify compilation succeeded (exit code 0 or 1 for runtime error is fine)
if execResult.ExitCode != 0 && execResult.ExitCode != 1 {
if strings.Contains(execResult.Output, "undefined:") {
t.Errorf("compilation failed with undefined symbol:\n%s", execResult.Output)
}
}
}