diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index c28ca6ca2f..d3ab267bf2 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -2,24 +2,27 @@ package tools import ( "context" + "errors" "fmt" - "os" + "io/fs" "strings" ) // EditFileTool edits a file by replacing old_text with new_text. // The old_text must exist exactly in the file. type EditFileTool struct { - allowedDir string - restrict bool + fs fileSystem } // NewEditFileTool creates a new EditFileTool with optional directory restriction. -func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool { - return &EditFileTool{ - allowedDir: allowedDir, - restrict: restrict, - } +func NewEditFileTool(workspace string, restrict bool) *EditFileTool { + var fs fileSystem + if restrict { + fs = &sandboxFs{workspace: workspace} + } else { + fs = &hostFs{} + } + return &EditFileTool{fs: fs} } func (t *EditFileTool) Name() string { @@ -67,49 +70,24 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("new_text is required") } - resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) - if err != nil { + if err := editFile(t.fs, path, oldText, newText); err != nil { return ErrorResult(err.Error()) } - - if _, err = os.Stat(resolvedPath); os.IsNotExist(err) { - return ErrorResult(fmt.Sprintf("file not found: %s", path)) - } - - content, err := os.ReadFile(resolvedPath) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) - } - - contentStr := string(content) - - if !strings.Contains(contentStr, oldText) { - return ErrorResult("old_text not found in file. Make sure it matches exactly") - } - - count := strings.Count(contentStr, oldText) - if count > 1 { - return ErrorResult( - fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count), - ) - } - - newContent := strings.Replace(contentStr, oldText, newText, 1) - - if err := os.WriteFile(resolvedPath, []byte(newContent), 0o644); err != nil { - return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) - } - return SilentResult(fmt.Sprintf("File edited: %s", path)) } type AppendFileTool struct { - workspace string - restrict bool + fs fileSystem } func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { - return &AppendFileTool{workspace: workspace, restrict: restrict} + var fs fileSystem + if restrict { + fs = &sandboxFs{workspace: workspace} + } else { + fs = &hostFs{} + } + return &AppendFileTool{fs: fs} } func (t *AppendFileTool) Name() string { @@ -148,20 +126,52 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *Tool return ErrorResult("content is required") } - resolvedPath, err := validatePath(path, t.workspace, t.restrict) - if err != nil { + if err := appendFile(t.fs, path, content); err != nil { return ErrorResult(err.Error()) } + return SilentResult(fmt.Sprintf("Appended to %s", path)) +} - f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) +// editFile reads the file via sysFs, performs the replacement, and writes back. +// It uses a fileSystem interface, allowing the same logic for both restricted and unrestricted modes. +func editFile(sysFs fileSystem, path, oldText, newText string) error { + content, err := sysFs.ReadFile(path) if err != nil { - return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) + return err } - defer f.Close() - if _, err := f.WriteString(content); err != nil { - return ErrorResult(fmt.Sprintf("failed to append to file: %v", err)) + newContent, err := replaceEditContent(content, oldText, newText) + if err != nil { + return err } - return SilentResult(fmt.Sprintf("Appended to %s", path)) + return sysFs.WriteFile(path, newContent) +} + +// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back. +func appendFile(sysFs fileSystem, path, appendContent string) error { + content, err := sysFs.ReadFile(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + + newContent := append(content, []byte(appendContent)...) + return sysFs.WriteFile(path, newContent) +} + +// replaceEditContent handles the core logic of finding and replacing a single occurrence of oldText. +func replaceEditContent(content []byte, oldText, newText string) ([]byte, error) { + contentStr := string(content) + + if !strings.Contains(contentStr, oldText) { + return nil, fmt.Errorf("old_text not found in file. Make sure it matches exactly") + } + + count := strings.Count(contentStr, oldText) + if count > 1 { + return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) + } + + newContent := strings.Replace(contentStr, oldText, newText, 1) + return []byte(newContent), nil } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go index 6780dd9f6c..83a7e778ca 100644 --- a/pkg/tools/edit_test.go +++ b/pkg/tools/edit_test.go @@ -6,6 +6,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/assert" ) // TestEditTool_EditFile_Success verifies successful file editing @@ -151,14 +153,18 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { result := tool.Execute(ctx, args) // Should return error result - if !result.IsError { - t.Errorf("Expected error when path is outside allowed directory") - } + assert.True(t, result.IsError, "Expected error when path is outside allowed directory") // Should mention outside allowed directory - if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") { - t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM) - } + // Note: ErrorResult only sets ForLLM by default, so ForUser might be empty. + // We check ForLLM as it's the primary error channel. + assert.True( + t, + strings.Contains(result.ForLLM, "outside") || strings.Contains(result.ForLLM, "access denied") || + strings.Contains(result.ForLLM, "escapes"), + "Expected 'outside allowed' or 'access denied' message, got ForLLM: %s", + result.ForLLM, + ) } // TestEditTool_EditFile_MissingPath verifies error handling for missing path @@ -287,3 +293,145 @@ func TestEditTool_AppendFile_MissingContent(t *testing.T) { t.Errorf("Expected error when content is missing") } } + +// TestReplaceEditContent verifies the helper function replaceEditContent +func TestReplaceEditContent(t *testing.T) { + tests := []struct { + name string + content []byte + oldText string + newText string + expected []byte + expectError bool + }{ + { + name: "successful replacement", + content: []byte("hello world"), + oldText: "world", + newText: "universe", + expected: []byte("hello universe"), + expectError: false, + }, + { + name: "old text not found", + content: []byte("hello world"), + oldText: "golang", + newText: "rust", + expected: nil, + expectError: true, + }, + { + name: "multiple matches found", + content: []byte("test text test"), + oldText: "test", + newText: "done", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := replaceEditContent(tt.content, tt.oldText, tt.newText) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestAppendFileTool_AppendToNonExistent_Restricted verifies that AppendFileTool in restricted mode +// can append to a file that does not yet exist — it should silently create the file. +// This exercises the errors.Is(err, fs.ErrNotExist) path in appendFileWithRW + rootRW. +func TestAppendFileTool_AppendToNonExistent_Restricted(t *testing.T) { + workspace := t.TempDir() + tool := NewAppendFileTool(workspace, true) + ctx := context.Background() + + args := map[string]any{ + "path": "brand_new_file.txt", + "content": "first content", + } + + result := tool.Execute(ctx, args) + assert.False( + t, + result.IsError, + "Expected success when appending to non-existent file in restricted mode, got: %s", + result.ForLLM, + ) + + // Verify the file was created with correct content + data, err := os.ReadFile(filepath.Join(workspace, "brand_new_file.txt")) + assert.NoError(t, err) + assert.Equal(t, "first content", string(data)) +} + +// TestAppendFileTool_Restricted_Success verifies that AppendFileTool in restricted mode +// correctly appends to an existing file within the sandbox. +func TestAppendFileTool_Restricted_Success(t *testing.T) { + workspace := t.TempDir() + testFile := "existing.txt" + err := os.WriteFile(filepath.Join(workspace, testFile), []byte("initial"), 0o644) + assert.NoError(t, err) + + tool := NewAppendFileTool(workspace, true) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "content": " appended", + } + + result := tool.Execute(ctx, args) + assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) + assert.True(t, result.Silent) + + data, err := os.ReadFile(filepath.Join(workspace, testFile)) + assert.NoError(t, err) + assert.Equal(t, "initial appended", string(data)) +} + +// TestEditFileTool_Restricted_InPlaceEdit verifies that EditFileTool in restricted mode +// correctly edits a file using the single-open editFileInRoot path. +func TestEditFileTool_Restricted_InPlaceEdit(t *testing.T) { + workspace := t.TempDir() + testFile := "edit_target.txt" + err := os.WriteFile(filepath.Join(workspace, testFile), []byte("Hello World"), 0o644) + assert.NoError(t, err) + + tool := NewEditFileTool(workspace, true) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "old_text": "World", + "new_text": "Go", + } + + result := tool.Execute(ctx, args) + assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) + assert.True(t, result.Silent) + + data, err := os.ReadFile(filepath.Join(workspace, testFile)) + assert.NoError(t, err) + assert.Equal(t, "Hello Go", string(data)) +} + +// TestEditFileTool_Restricted_FileNotFound verifies that editFileInRoot returns a proper +// error message when the target file does not exist. +func TestEditFileTool_Restricted_FileNotFound(t *testing.T) { + workspace := t.TempDir() + tool := NewEditFileTool(workspace, true) + ctx := context.Background() + args := map[string]any{ + "path": "no_such_file.txt", + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "not found") +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 1bf50906e5..37db8b4ae5 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -3,15 +3,17 @@ package tools import ( "context" "fmt" + "io/fs" "os" "path/filepath" "strings" + "time" ) // validatePath ensures the given path is within the workspace if restrict is true. func validatePath(path, workspace string, restrict bool) (string, error) { if workspace == "" { - return path, nil + return path, fmt.Errorf("workspace is not defined") } absWorkspace, err := filepath.Abs(workspace) @@ -76,16 +78,21 @@ func resolveExistingAncestor(path string) (string, error) { func isWithinWorkspace(candidate, workspace string) bool { rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate)) - return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) + return err == nil && filepath.IsLocal(rel) } type ReadFileTool struct { - workspace string - restrict bool + fs fileSystem } func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { - return &ReadFileTool{workspace: workspace, restrict: restrict} + var fs fileSystem + if restrict { + fs = &sandboxFs{workspace: workspace} + } else { + fs = &hostFs{} + } + return &ReadFileTool{fs: fs} } func (t *ReadFileTool) Name() string { @@ -115,26 +122,25 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("path is required") } - resolvedPath, err := validatePath(path, t.workspace, t.restrict) + content, err := t.fs.ReadFile(path) if err != nil { return ErrorResult(err.Error()) } - - content, err := os.ReadFile(resolvedPath) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) - } - return NewToolResult(string(content)) } type WriteFileTool struct { - workspace string - restrict bool + fs fileSystem } func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { - return &WriteFileTool{workspace: workspace, restrict: restrict} + var fs fileSystem + if restrict { + fs = &sandboxFs{workspace: workspace} + } else { + fs = &hostFs{} + } + return &WriteFileTool{fs: fs} } func (t *WriteFileTool) Name() string { @@ -173,30 +179,25 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR return ErrorResult("content is required") } - resolvedPath, err := validatePath(path, t.workspace, t.restrict) - if err != nil { + if err := t.fs.WriteFile(path, []byte(content)); err != nil { return ErrorResult(err.Error()) } - dir := filepath.Dir(resolvedPath) - if err := os.MkdirAll(dir, 0o755); err != nil { - return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) - } - - if err := os.WriteFile(resolvedPath, []byte(content), 0o644); err != nil { - return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) - } - return SilentResult(fmt.Sprintf("File written: %s", path)) } type ListDirTool struct { - workspace string - restrict bool + fs fileSystem } func NewListDirTool(workspace string, restrict bool) *ListDirTool { - return &ListDirTool{workspace: workspace, restrict: restrict} + var fs fileSystem + if restrict { + fs = &sandboxFs{workspace: workspace} + } else { + fs = &hostFs{} + } + return &ListDirTool{fs: fs} } func (t *ListDirTool) Name() string { @@ -226,24 +227,179 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolRes path = "." } - resolvedPath, err := validatePath(path, t.workspace, t.restrict) - if err != nil { - return ErrorResult(err.Error()) - } - - entries, err := os.ReadDir(resolvedPath) + entries, err := t.fs.ReadDir(path) if err != nil { return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } + return formatDirEntries(entries) +} - result := "" +func formatDirEntries(entries []os.DirEntry) *ToolResult { + var result strings.Builder for _, entry := range entries { if entry.IsDir() { - result += "DIR: " + entry.Name() + "\n" + result.WriteString("DIR: " + entry.Name() + "\n") } else { - result += "FILE: " + entry.Name() + "\n" + result.WriteString("FILE: " + entry.Name() + "\n") + } + } + return NewToolResult(result.String()) +} + +// fileSystem abstracts reading, writing, and listing files, allowing both +// unrestricted (host filesystem) and sandbox (os.Root) implementations to share the same polymorphic interface. +type fileSystem interface { + ReadFile(path string) ([]byte, error) + WriteFile(path string, data []byte) error + ReadDir(path string) ([]os.DirEntry, error) +} + +// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem. +type hostFs struct{} + +func (h *hostFs) ReadFile(path string) ([]byte, error) { + content, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("failed to read file: file not found: %w", err) + } + if os.IsPermission(err) { + return nil, fmt.Errorf("failed to read file: access denied: %w", err) + } + return nil, fmt.Errorf("failed to read file: %w", err) + } + return content, nil +} + +func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) { + return os.ReadDir(path) +} + +func (h *hostFs) WriteFile(path string, data []byte) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("failed to create parent directories: %w", err) + } + + // We use a "write-then-rename" pattern here to ensure an atomic write. + // This prevents the target file from being left in a truncated or partial state + // if the operation is interrupted, as the rename operation is atomic on Linux. + tmpPath := fmt.Sprintf("%s.%d.tmp", path, time.Now().UnixNano()) + if err := os.WriteFile(tmpPath, data, 0o644); err != nil { + os.Remove(tmpPath) // Ensure cleanup of partial/empty temp file + return fmt.Errorf("failed to write temp file: %w", err) + } + + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("failed to replace original file: %w", err) + } + return nil +} + +// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root. +type sandboxFs struct { + workspace string +} + +func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error { + if r.workspace == "" { + return fmt.Errorf("workspace is not defined") + } + + root, err := os.OpenRoot(r.workspace) + if err != nil { + return fmt.Errorf("failed to open workspace: %w", err) + } + defer root.Close() + + relPath, err := getSafeRelPath(r.workspace, path) + if err != nil { + return err + } + + return fn(root, relPath) +} + +func (r *sandboxFs) ReadFile(path string) ([]byte, error) { + var content []byte + err := r.execute(path, func(root *os.Root, relPath string) error { + fileContent, err := root.ReadFile(relPath) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("failed to read file: file not found: %w", err) + } + // os.Root returns "escapes from parent" for paths outside the root + if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") || + strings.Contains(err.Error(), "permission denied") { + return fmt.Errorf("failed to read file: access denied: %w", err) + } + return fmt.Errorf("failed to read file: %w", err) } + content = fileContent + return nil + }) + return content, err +} + +func (r *sandboxFs) WriteFile(path string, data []byte) error { + return r.execute(path, func(root *os.Root, relPath string) error { + dir := filepath.Dir(relPath) + if dir != "." && dir != "/" { + if err := root.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("failed to create parent directories: %w", err) + } + } + + // We use a "write-then-rename" pattern here to ensure an atomic write. + // This prevents the target file from being left in a truncated or partial state + // if the operation is interrupted, as the rename operation is atomic on Linux. + tmpRelPath := fmt.Sprintf("%s.%d.tmp", relPath, time.Now().UnixNano()) + + if err := root.WriteFile(tmpRelPath, data, 0o644); err != nil { + root.Remove(tmpRelPath) // Ensure cleanup of partial/empty temp file + return fmt.Errorf("failed to write to temp file: %w", err) + } + + if err := root.Rename(tmpRelPath, relPath); err != nil { + root.Remove(tmpRelPath) + return fmt.Errorf("failed to rename temp file over target: %w", err) + } + return nil + }) +} + +func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) { + var entries []os.DirEntry + err := r.execute(path, func(root *os.Root, relPath string) error { + dirEntries, err := fs.ReadDir(root.FS(), relPath) + if err != nil { + return err + } + entries = dirEntries + return nil + }) + return entries, err +} + +// Helper to get a safe relative path for os.Root usage +func getSafeRelPath(workspace, path string) (string, error) { + if workspace == "" { + return "", fmt.Errorf("workspace is not defined") + } + + rel := filepath.Clean(path) + if filepath.IsAbs(rel) { + var err error + rel, err = filepath.Rel(workspace, rel) + if err != nil { + return "", fmt.Errorf("failed to calculate relative path: %w", err) + } + } + + if !filepath.IsLocal(rel) { + return "", fmt.Errorf("path escapes workspace: %s", path) } - return NewToolResult(result) + return rel, nil } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 5daa3dceae..6f896e22d5 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -2,10 +2,13 @@ package tools import ( "context" + "io" "os" "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/assert" ) // TestFilesystemTool_ReadFile_Success verifies successful file reading @@ -14,7 +17,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("test content"), 0o644) - tool := &ReadFileTool{} + tool := NewReadFileTool("", false) ctx := context.Background() args := map[string]any{ "path": testFile, @@ -41,7 +44,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) { // TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { - tool := &ReadFileTool{} + tool := NewReadFileTool("", false) ctx := context.Background() args := map[string]any{ "path": "/nonexistent_file_12345.txt", @@ -84,7 +87,7 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "newfile.txt") - tool := &WriteFileTool{} + tool := NewWriteFileTool("", false) ctx := context.Background() args := map[string]any{ "path": testFile, @@ -123,7 +126,7 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") - tool := &WriteFileTool{} + tool := NewWriteFileTool("", false) ctx := context.Background() args := map[string]any{ "path": testFile, @@ -149,7 +152,7 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { // TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { - tool := &WriteFileTool{} + tool := NewWriteFileTool("", false) ctx := context.Background() args := map[string]any{ "content": "test", @@ -165,7 +168,7 @@ func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { // TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { - tool := &WriteFileTool{} + tool := NewWriteFileTool("", false) ctx := context.Background() args := map[string]any{ "path": "/tmp/test.txt", @@ -192,7 +195,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) { os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644) os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755) - tool := &ListDirTool{} + tool := NewListDirTool("", false) ctx := context.Background() args := map[string]any{ "path": tmpDir, @@ -216,7 +219,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) { // TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory func TestFilesystemTool_ListDir_NotFound(t *testing.T) { - tool := &ListDirTool{} + tool := NewListDirTool("", false) ctx := context.Background() args := map[string]any{ "path": "/nonexistent_directory_12345", @@ -237,7 +240,7 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) { // TestFilesystemTool_ListDir_DefaultPath verifies default to current directory func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { - tool := &ListDirTool{} + tool := NewListDirTool("", false) ctx := context.Background() args := map[string]any{} @@ -275,7 +278,211 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { if !result.IsError { t.Fatalf("expected symlink escape to be blocked") } - if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") { + // os.Root might return different errors depending on platform/implementation + // but it definitely should error. + // Our wrapper returns "access denied or file not found" + if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") && + !strings.Contains(result.ForLLM, "no such file") { t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) } } + +func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) { + tool := NewReadFileTool("", true) // restrict=true but workspace="" + + // Try to read a sensitive file (simulated by a temp file outside workspace) + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "shadow") + os.WriteFile(secretFile, []byte("secret data"), 0o600) + + result := tool.Execute(context.Background(), map[string]any{ + "path": secretFile, + }) + + // We EXPECT IsError=true (access blocked due to empty workspace) + assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM) + + // Verify it failed for the right reason + assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error") +} + +// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases: +// single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path. +func TestRootMkdirAll(t *testing.T) { + workspace := t.TempDir() + root, err := os.OpenRoot(workspace) + if err != nil { + t.Fatalf("failed to open root: %v", err) + } + defer root.Close() + + // Case 1: Single directory + err = root.MkdirAll("dir1", 0o755) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(workspace, "dir1")) + assert.NoError(t, err) + + // Case 2: Deeply nested directory + err = root.MkdirAll("a/b/c/d", 0o755) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(workspace, "a/b/c/d")) + assert.NoError(t, err) + + // Case 3: Already exists — must be idempotent + err = root.MkdirAll("a/b/c/d", 0o755) + assert.NoError(t, err) + + // Case 4: A regular file blocks directory creation — must error + err = os.WriteFile(filepath.Join(workspace, "file_exists"), []byte("data"), 0o644) + assert.NoError(t, err) + err = root.MkdirAll("file_exists", 0o755) + assert.Error(t, err, "expected error when a file exists at the directory path") +} + +func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) { + workspace := t.TempDir() + tool := NewWriteFileTool(workspace, true) + ctx := context.Background() + + testFile := "deep/nested/path/to/file.txt" + content := "deep content" + args := map[string]any{ + "path": testFile, + "content": content, + } + + result := tool.Execute(ctx, args) + assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) + + // Verify file content + actualPath := filepath.Join(workspace, testFile) + data, err := os.ReadFile(actualPath) + assert.NoError(t, err) + assert.Equal(t, content, string(data)) +} + +// TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors. +func TestHostRW_Read_PermissionDenied(t *testing.T) { + if os.Getuid() == 0 { + t.Skip("skipping permission test: running as root") + } + tmpDir := t.TempDir() + protected := filepath.Join(tmpDir, "protected.txt") + err := os.WriteFile(protected, []byte("secret"), 0o000) + assert.NoError(t, err) + defer os.Chmod(protected, 0o644) // ensure cleanup + + _, err = (&hostFs{}).ReadFile(protected) + assert.Error(t, err) + assert.Contains(t, err.Error(), "access denied") +} + +// TestHostRW_Read_Directory verifies that hostRW.Read returns an error when given a directory path. +func TestHostRW_Read_Directory(t *testing.T) { + tmpDir := t.TempDir() + + _, err := (&hostFs{}).ReadFile(tmpDir) + assert.Error(t, err, "expected error when reading a directory as a file") +} + +// TestRootRW_Read_Directory verifies that rootRW.Read returns an error when given a directory. +func TestRootRW_Read_Directory(t *testing.T) { + workspace := t.TempDir() + root, err := os.OpenRoot(workspace) + assert.NoError(t, err) + defer root.Close() + + // Create a subdirectory + err = root.Mkdir("subdir", 0o755) + assert.NoError(t, err) + + _, err = (&sandboxFs{workspace: workspace}).ReadFile("subdir") + assert.Error(t, err, "expected error when reading a directory as a file") +} + +// TestHostRW_Write_ParentDirMissing verifies that hostRW.Write creates parent dirs automatically. +func TestHostRW_Write_ParentDirMissing(t *testing.T) { + tmpDir := t.TempDir() + target := filepath.Join(tmpDir, "a", "b", "c", "file.txt") + + err := (&hostFs{}).WriteFile(target, []byte("hello")) + assert.NoError(t, err) + + data, err := os.ReadFile(target) + assert.NoError(t, err) + assert.Equal(t, "hello", string(data)) +} + +// TestRootRW_Write_ParentDirMissing verifies that rootRW.Write creates +// nested parent directories automatically within the sandbox. +func TestRootRW_Write_ParentDirMissing(t *testing.T) { + workspace := t.TempDir() + + relPath := "x/y/z/file.txt" + err := (&sandboxFs{workspace: workspace}).WriteFile(relPath, []byte("nested")) + assert.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(workspace, relPath)) + assert.NoError(t, err) + assert.Equal(t, "nested", string(data)) +} + +// TestHostRW_Write verifies the hostRW.Write helper function +func TestHostRW_Write(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "atomic_test.txt") + testData := []byte("atomic test content") + + err := (&hostFs{}).WriteFile(testFile, testData) + assert.NoError(t, err) + + content, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, testData, content) + + // Verify it overwrites correctly + newData := []byte("new atomic content") + err = (&hostFs{}).WriteFile(testFile, newData) + assert.NoError(t, err) + + content, err = os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, newData, content) +} + +// TestRootRW_Write verifies the rootRW.Write helper function +func TestRootRW_Write(t *testing.T) { + tmpDir := t.TempDir() + + relPath := "atomic_root_test.txt" + testData := []byte("atomic root test content") + + erw := &sandboxFs{workspace: tmpDir} + err := erw.WriteFile(relPath, testData) + assert.NoError(t, err) + + root, err := os.OpenRoot(tmpDir) + assert.NoError(t, err) + defer root.Close() + + f, err := root.Open(relPath) + assert.NoError(t, err) + defer f.Close() + + content, err := io.ReadAll(f) + assert.NoError(t, err) + assert.Equal(t, testData, content) + + // Verify it overwrites correctly + newData := []byte("new root atomic content") + err = erw.WriteFile(relPath, newData) + assert.NoError(t, err) + + f2, err := root.Open(relPath) + assert.NoError(t, err) + defer f2.Close() + + content, err = io.ReadAll(f2) + assert.NoError(t, err) + assert.Equal(t, newData, content) +}