diff --git a/README.md b/README.md index e80e2213ce..6872ae5ec2 100644 --- a/README.md +++ b/README.md @@ -528,6 +528,21 @@ When `restrict_to_workspace: true`, the following tools are sandboxed: | `append_file` | Append to files | Only files within workspace | | `exec` | Execute commands | Command paths must be within workspace | +#### Web Fetch Network Boundary + +`web_fetch` enforces an outbound network boundary independent of `restrict_to_workspace`. + +Blocked destination classes include: + +* loopback +* private RFC1918 / unique-local ranges +* link-local +* multicast +* unspecified / non-routable internal targets +* redirect hops that resolve to blocked targets + +This policy is applied both before connect and during redirect handling, so a public URL cannot bounce into private infrastructure through redirects. + #### Additional Exec Protection Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands: @@ -551,6 +566,11 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous {tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} ``` +``` +[ERROR] tool: Tool execution failed +{tool=web_fetch, error=blocked destination: host "127.0.0.1" resolves to non-public IP 127.0.0.1} +``` + #### Disabling Restrictions (Security Risk) If you need the agent to access paths outside the workspace: @@ -577,7 +597,12 @@ export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false #### Security Boundary Consistency -The `restrict_to_workspace` setting applies consistently across all execution paths: +PicoClaw enforces complementary boundaries: + +* Filesystem + shell boundary via `restrict_to_workspace` +* Network egress boundary via `web_fetch` public-target validation + +The workspace boundary (`restrict_to_workspace`) applies consistently across all execution paths: | Execution Path | Security Boundary | |----------------|-------------------| diff --git a/pkg/channels/telegram.go 
b/pkg/channels/telegram.go index 5601d508c3..b7614a9b68 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -39,6 +39,10 @@ type thinkingCancel struct { fn context.CancelFunc } +const ( + telegramMaxMessageLength = 4096 +) + func (c *thinkingCancel) Cancel() { if c != nil && c.fn != nil { c.fn() @@ -157,35 +161,66 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err c.stopThinking.Delete(msg.ChatID) } - htmlContent := markdownToTelegramHTML(msg.Content) + chunks := splitTelegramMessageContent(msg.Content, telegramMaxMessageLength) + if len(chunks) == 0 { + return nil + } // Try to edit placeholder if pID, ok := c.placeholders.Load(msg.ChatID); ok { c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { + if err := c.editMessageChunk(ctx, chatID, pID.(int), chunks[0]); err == nil { + for i := 1; i < len(chunks); i++ { + if sendErr := c.sendMessageChunk(ctx, chatID, chunks[i]); sendErr != nil { + return sendErr + } + } return nil } // Fallback to new message if edit fails } + for _, chunk := range chunks { + if err := c.sendMessageChunk(ctx, chatID, chunk); err != nil { + return err + } + } + + return nil +} + +func (c *TelegramChannel) sendMessageChunk(ctx context.Context, chatID int64, content string) error { + htmlContent := markdownToTelegramHTML(content) tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML - if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil { logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ "error": err.Error(), }) - tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) - return err + plainMsg := tu.Message(tu.ID(chatID), content) + _, fallbackErr := c.bot.SendMessage(ctx, plainMsg) + return 
fallbackErr } return nil } +func (c *TelegramChannel) editMessageChunk(ctx context.Context, chatID int64, messageID int, content string) error { + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(chatID), messageID, htmlContent) + editMsg.ParseMode = telego.ModeHTML + if _, err := c.bot.EditMessageText(ctx, editMsg); err != nil { + logger.ErrorCF("telegram", "HTML edit parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) + plainEdit := tu.EditMessageText(tu.ID(chatID), messageID, content) + _, fallbackErr := c.bot.EditMessageText(ctx, plainEdit) + return fallbackErr + } + return nil +} + func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message == nil { return fmt.Errorf("message is nil") @@ -453,6 +488,67 @@ func markdownToTelegramHTML(text string) string { return text } +func splitTelegramMessageContent(text string, maxLen int) []string { + text = strings.TrimSpace(text) + if text == "" { + return nil + } + + if maxLen <= 0 { + return []string{text} + } + + chunks := utils.SplitMessage(text, maxLen) + return enforceTelegramMessageHTMLLimit(chunks, maxLen) +} + +func enforceTelegramMessageHTMLLimit(chunks []string, maxLen int) []string { + out := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + + if runeLen(markdownToTelegramHTML(chunk)) <= maxLen { + out = append(out, chunk) + continue + } + + runes := []rune(chunk) + if len(runes) <= 1 { + out = append(out, chunk) + continue + } + + splitLimit := len(runes) / 2 + if splitLimit > maxLen { + splitLimit = maxLen + } + if splitLimit < 1 { + splitLimit = 1 + } + + subChunks := utils.SplitMessage(chunk, splitLimit) + if len(subChunks) <= 1 { + mid := len(runes) / 2 + if mid < 1 { + mid = 1 + } + subChunks = []string{ + string(runes[:mid]), + string(runes[mid:]), + } + } + out = append(out, 
enforceTelegramMessageHTMLLimit(subChunks, maxLen)...) + } + return out +} + +func runeLen(text string) int { + return len([]rune(text)) +} + type codeBlockMatch struct { text string codes []string diff --git a/pkg/channels/telegram_test.go b/pkg/channels/telegram_test.go new file mode 100644 index 0000000000..fbf77d6e74 --- /dev/null +++ b/pkg/channels/telegram_test.go @@ -0,0 +1,54 @@ +package channels + +import ( + "strings" + "testing" +) + +func TestSplitTelegramMessageContentShortMessage(t *testing.T) { + input := "hello world" + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) != 1 { + t.Fatalf("len(chunks) = %d, want 1", len(chunks)) + } + if chunks[0] != input { + t.Fatalf("chunk[0] = %q, want %q", chunks[0], input) + } +} + +func TestSplitTelegramMessageContentLongMessage(t *testing.T) { + input := strings.Repeat("This is a long telegram message chunk. ", 300) + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + + for i, chunk := range chunks { + if strings.TrimSpace(chunk) == "" { + t.Fatalf("chunk %d is empty", i) + } + html := markdownToTelegramHTML(chunk) + if runeLen(html) > telegramMaxMessageLength { + t.Fatalf("chunk %d HTML length = %d, want <= %d", i, runeLen(html), telegramMaxMessageLength) + } + } +} + +func TestSplitTelegramMessageContentEscapingExpansion(t *testing.T) { + // '&' expands to '&amp;' in HTML, so this validates recursive splitting safety. 
+ input := strings.Repeat("&", 5000) + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + + for i, chunk := range chunks { + html := markdownToTelegramHTML(chunk) + if runeLen(html) > telegramMaxMessageLength { + t.Fatalf("chunk %d escaped HTML length = %d, want <= %d", i, runeLen(html), telegramMaxMessageLength) + } + } +} diff --git a/pkg/tools/README.md b/pkg/tools/README.md new file mode 100644 index 0000000000..83143c2115 --- /dev/null +++ b/pkg/tools/README.md @@ -0,0 +1,9 @@ +# Tool Security Checklist + +When adding a new built-in tool, include these minimum safety checks: + +1. Path boundary: if the tool reads/writes files or executes commands with paths, enforce canonical workspace membership when `restrict_to_workspace=true`. +2. Network boundary: if the tool performs outbound network calls, reject loopback/private/link-local/multicast/unspecified/internal targets and validate redirect hops. +3. Timeout behavior: long-running operations must use deterministic timeout/cancel handling and terminate child processes where process trees are possible. +4. Regression tests: add explicit tests for blocked behavior (not just happy-path errors), including redirect/path traversal/process-leak scenarios where relevant. +5. Error clarity: return explicit denial reasons (`blocked destination`, `outside workspace`, `timed out`) so behavior is auditable in logs. 
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 09063ea0a6..e97a9b6f84 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -19,62 +19,98 @@ func validatePath(path, workspace string, restrict bool) (string, error) { return "", fmt.Errorf("failed to resolve workspace path: %w", err) } - var absPath string + absPath, err := resolveTargetPath(path, absWorkspace) + if err != nil { + return "", err + } + + if !restrict { + return absPath, nil + } + + canonicalPath, err := pathWithinWorkspace(absPath, absWorkspace) + if err != nil { + return "", err + } + + return canonicalPath, nil +} + +func resolveTargetPath(path, absWorkspace string) (string, error) { if filepath.IsAbs(path) { - absPath = filepath.Clean(path) - } else { - absPath, err = filepath.Abs(filepath.Join(absWorkspace, path)) - if err != nil { - return "", fmt.Errorf("failed to resolve file path: %w", err) - } + return filepath.Clean(path), nil } - if restrict { - if !isWithinWorkspace(absPath, absWorkspace) { - return "", fmt.Errorf("access denied: path is outside the workspace") - } + absPath, err := filepath.Abs(filepath.Join(absWorkspace, path)) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + return filepath.Clean(absPath), nil +} - workspaceReal := absWorkspace - if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil { - workspaceReal = resolved - } +func pathWithinWorkspace(target, workspace string) (string, error) { + canonicalWorkspace, err := canonicalizeExistingPath(workspace) + if err != nil { + return "", fmt.Errorf("failed to canonicalize workspace path: %w", err) + } - if resolved, err := filepath.EvalSymlinks(absPath); err == nil { - if !isWithinWorkspace(resolved, workspaceReal) { - return "", fmt.Errorf("access denied: symlink resolves outside workspace") - } - } else if os.IsNotExist(err) { - if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil { - if 
!isWithinWorkspace(parentResolved, workspaceReal) { - return "", fmt.Errorf("access denied: symlink resolves outside workspace") - } - } else if !os.IsNotExist(err) { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - } else { - return "", fmt.Errorf("failed to resolve path: %w", err) - } + canonicalTarget, err := canonicalizePathForBoundary(target) + if err != nil { + return "", fmt.Errorf("failed to canonicalize target path: %w", err) + } + + rel, err := filepath.Rel(canonicalWorkspace, canonicalTarget) + if err != nil { + return "", fmt.Errorf("failed to evaluate workspace boundary: %w", err) } - return absPath, nil + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + return canonicalTarget, nil +} + +func canonicalizeExistingPath(path string) (string, error) { + resolved, err := filepath.EvalSymlinks(path) + if err != nil { + return "", err + } + return filepath.Clean(resolved), nil } -func resolveExistingAncestor(path string) (string, error) { - for current := filepath.Clean(path); ; current = filepath.Dir(current) { - if resolved, err := filepath.EvalSymlinks(current); err == nil { - return resolved, nil - } else if !os.IsNotExist(err) { +func canonicalizePathForBoundary(path string) (string, error) { + cleanPath := filepath.Clean(path) + segments := make([]string, 0, 4) + current := cleanPath + + for { + _, err := os.Lstat(current) + if err == nil { + resolved, evalErr := filepath.EvalSymlinks(current) + if evalErr != nil { + return "", evalErr + } + + for i := len(segments) - 1; i >= 0; i-- { + resolved = filepath.Join(resolved, segments[i]) + } + + return filepath.Clean(resolved), nil + } + + if !os.IsNotExist(err) { return "", err } - if filepath.Dir(current) == current { - return "", os.ErrNotExist + + parent := filepath.Dir(current) + if parent == current { + return "", fmt.Errorf("could not resolve existing parent for %q", path) } - } 
-} -func isWithinWorkspace(candidate, workspace string) bool { - rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate)) - return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) + segments = append(segments, filepath.Base(current)) + current = parent + } } type ReadFileTool struct { diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 9580364195..f245e4d56c 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -248,34 +248,54 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { } } -// Block paths that look inside workspace but point outside via symlink. -func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { +// TestFilesystemTool_Restrict_BlocksSymlinkEscape verifies symlink traversal outside workspace is blocked. +func TestFilesystemTool_Restrict_BlocksSymlinkEscape(t *testing.T) { + workspace := t.TempDir() + outside := t.TempDir() + secretFile := filepath.Join(outside, "secret.txt") + if err := os.WriteFile(secretFile, []byte("do-not-read"), 0644); err != nil { + t.Fatalf("failed to create outside file: %v", err) + } + + linkPath := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secretFile, linkPath); err != nil { + t.Skipf("symlink unavailable in this environment: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]interface{}{"path": "leak.txt"}) + if !result.IsError { + t.Fatalf("Expected symlink escape to be blocked") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "outside") { + t.Fatalf("Expected workspace boundary error, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_Restrict_BlocksPrefixBypass verifies prefix confusion paths are blocked. 
+func TestFilesystemTool_Restrict_BlocksPrefixBypass(t *testing.T) { + parent := t.TempDir() + workspace := filepath.Join(parent, "workspace") + prefixBypassDir := filepath.Join(parent, "workspace-evil") - root := t.TempDir() - workspace := filepath.Join(root, "workspace") if err := os.MkdirAll(workspace, 0755); err != nil { t.Fatalf("failed to create workspace: %v", err) } - - secret := filepath.Join(root, "secret.txt") - if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil { - t.Fatalf("failed to write secret file: %v", err) + if err := os.MkdirAll(prefixBypassDir, 0755); err != nil { + t.Fatalf("failed to create bypass dir: %v", err) } - link := filepath.Join(workspace, "leak.txt") - if err := os.Symlink(secret, link); err != nil { - t.Skipf("symlink not supported in this environment: %v", err) + bypassFile := filepath.Join(prefixBypassDir, "stolen.txt") + if err := os.WriteFile(bypassFile, []byte("secret"), 0644); err != nil { + t.Fatalf("failed to create bypass file: %v", err) } tool := NewReadFileTool(workspace, true) - result := tool.Execute(context.Background(), map[string]interface{}{ - "path": link, - }) - + result := tool.Execute(context.Background(), map[string]interface{}{"path": bypassFile}) if !result.IsError { - t.Fatalf("expected symlink escape to be blocked") + t.Fatalf("Expected prefix bypass path to be blocked") } - if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") { - t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) + if !strings.Contains(strings.ToLower(result.ForLLM), "outside") { + t.Fatalf("Expected workspace boundary error, got: %s", result.ForLLM) } } diff --git a/pkg/tools/network_guard.go b/pkg/tools/network_guard.go new file mode 100644 index 0000000000..17eaa3bf21 --- /dev/null +++ b/pkg/tools/network_guard.go @@ -0,0 +1,178 @@ +package tools + +import ( + "context" + "fmt" + "net" + "net/netip" + "net/url" + "strings" +) + +var ( + cgnatPrefix = 
netip.MustParsePrefix("100.64.0.0/10") + benchmarkPrefix = netip.MustParsePrefix("198.18.0.0/15") + reservedPrefix = netip.MustParsePrefix("240.0.0.0/4") +) + +type fetchTargetValidator struct { + resolver *net.Resolver + allowedHosts map[string]struct{} +} + +func newFetchTargetValidator(allowHosts []string, resolver *net.Resolver) *fetchTargetValidator { + allowed := make(map[string]struct{}, len(allowHosts)) + for _, host := range allowHosts { + normalized := normalizeHostToken(host) + if normalized != "" { + allowed[normalized] = struct{}{} + } + } + + if resolver == nil { + resolver = net.DefaultResolver + } + + return &fetchTargetValidator{ + resolver: resolver, + allowedHosts: allowed, + } +} + +func (v *fetchTargetValidator) validateURL(ctx context.Context, target *url.URL) error { + host := normalizeHostToken(target.Hostname()) + if host == "" { + return fmt.Errorf("missing domain in URL") + } + + port := target.Port() + if v.isAllowed(host, port) { + return nil + } + + if isBlockedHostname(host) { + return fmt.Errorf("blocked destination: host %q is internal-only", target.Hostname()) + } + + if ip, ok := parseIPLiteral(host); ok { + if IsBlockedIP(ip) { + return fmt.Errorf("blocked destination: IP %s is not publicly routable", ip) + } + return nil + } + + addrs, err := v.resolver.LookupNetIP(ctx, "ip", host) + if err != nil { + return fmt.Errorf("failed to resolve host %q: %w", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("failed to resolve host %q: no records", host) + } + + for _, addr := range addrs { + if IsBlockedIP(addr) { + return fmt.Errorf("blocked destination: host %q resolves to non-public IP %s", host, addr) + } + } + + return nil +} + +func (v *fetchTargetValidator) isAllowed(host, port string) bool { + if len(v.allowedHosts) == 0 { + return false + } + + if _, ok := v.allowedHosts[host]; ok { + return true + } + if port != "" { + if _, ok := v.allowedHosts[host+":"+port]; ok { + return true + } + } + return false +} + +// 
ValidateFetchTarget applies the default web fetch target policy. +func ValidateFetchTarget(target *url.URL) error { + return newFetchTargetValidator(nil, net.DefaultResolver).validateURL(context.Background(), target) +} + +// IsBlockedIP returns true for IPs that should not be reachable from web_fetch. +func IsBlockedIP(addr netip.Addr) bool { + if !addr.IsValid() { + return true + } + + if addr.IsLoopback() || + addr.IsPrivate() || + addr.IsLinkLocalUnicast() || + addr.IsLinkLocalMulticast() || + addr.IsMulticast() || + addr.IsUnspecified() || + addr.IsInterfaceLocalMulticast() { + return true + } + + if addr.Is4() { + if cgnatPrefix.Contains(addr) || benchmarkPrefix.Contains(addr) || reservedPrefix.Contains(addr) { + return true + } + } + + return false +} + +func isBlockedHostname(host string) bool { + if host == "localhost" || strings.HasSuffix(host, ".localhost") { + return true + } + if strings.HasSuffix(host, ".local") || strings.HasSuffix(host, ".internal") { + return true + } + if host == "metadata.google.internal" || host == "metadata" { + return true + } + return false +} + +func parseIPLiteral(host string) (netip.Addr, bool) { + if i := strings.Index(host, "%"); i >= 0 { + host = host[:i] + } + addr, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{}, false + } + return addr.Unmap(), true +} + +func normalizeHostToken(raw string) string { + return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(raw)), ".") +} + +func guardedDialContext(base *net.Dialer, validator *fetchTargetValidator) func(ctx context.Context, network, address string) (net.Conn, error) { + if base == nil { + base = &net.Dialer{} + } + + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + host = address + port = "" + } + + target := &url.URL{Host: host} + if port != "" { + target.Host = net.JoinHostPort(host, port) + } + + if err := validator.validateURL(ctx, target); err 
!= nil { + return nil, err + } + + return base.DialContext(ctx, network, address) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 713850f977..058446b28f 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -101,33 +101,62 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To var cmd *exec.Cmd if runtime.GOOS == "windows" { - cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command) + cmd = exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", command) } else { - cmd = exec.CommandContext(cmdCtx, "sh", "-c", command) + cmd = exec.Command("sh", "-c", command) } if cwd != "" { cmd.Dir = cwd } + prepareCommandForTreeControl(cmd) + var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr - err := cmd.Run() + if err := cmd.Start(); err != nil { + return ErrorResult(fmt.Sprintf("failed to start command: %v", err)) + } + + waitDone := make(chan error, 1) + go func() { + waitDone <- cmd.Wait() + }() + + var err error + timedOut := false + + select { + case err = <-waitDone: + case <-cmdCtx.Done(): + if killErr := killCommandTree(cmd); killErr != nil { + return ErrorResult(fmt.Sprintf("failed to terminate command tree: %v", killErr)) + } + + timedOut = cmdCtx.Err() == context.DeadlineExceeded + select { + case err = <-waitDone: + case <-time.After(3 * time.Second): + return ErrorResult("command termination timed out after deadline") + } + } + output := stdout.String() if stderr.Len() > 0 { output += "\nSTDERR:\n" + stderr.String() } - if err != nil { - if cmdCtx.Err() == context.DeadlineExceeded { - msg := fmt.Sprintf("Command timed out after %v", t.timeout) - return &ToolResult{ - ForLLM: msg, - ForUser: msg, - IsError: true, - } + if timedOut { + msg := fmt.Sprintf("Command timed out after %v", t.timeout) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + IsError: true, } + } + + if err != nil { output += fmt.Sprintf("\nExit code: %v", 
err) } diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/shell_process_unix.go new file mode 100644 index 0000000000..6037aa7489 --- /dev/null +++ b/pkg/tools/shell_process_unix.go @@ -0,0 +1,37 @@ +//go:build !windows + +package tools + +import ( + "fmt" + "os" + "os/exec" + "syscall" +) + +func prepareCommandForTreeControl(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} + +func killCommandTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err == nil { + if killErr := syscall.Kill(-pgid, syscall.SIGKILL); killErr == nil || killErr == syscall.ESRCH { + return nil + } else { + return fmt.Errorf("failed to kill process group %d: %w", pgid, killErr) + } + } + + if killErr := cmd.Process.Kill(); killErr != nil && killErr != os.ErrProcessDone { + return fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, killErr) + } + + return nil +} diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/shell_process_windows.go new file mode 100644 index 0000000000..6669b3dc87 --- /dev/null +++ b/pkg/tools/shell_process_windows.go @@ -0,0 +1,30 @@ +//go:build windows + +package tools + +import ( + "fmt" + "os/exec" + "strconv" + "syscall" +) + +func prepareCommandForTreeControl(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + +func killCommandTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pid := strconv.Itoa(cmd.Process.Pid) + killCmd := exec.Command("taskkill", "/T", "/F", "/PID", pid) + if err := killCmd.Run(); err != nil { + return fmt.Errorf("taskkill failed for pid %s: %w", pid, err) + } + + return nil +} diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index c06468a39a..603072dcde 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -2,8 +2,12 @@ package tools import ( "context" + "fmt" "os" + 
"os/exec" "path/filepath" + "runtime" + "strconv" "strings" "testing" "time" @@ -208,3 +212,95 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) { t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) } } + +// TestShellTool_Timeout_KillsChildProcesses verifies timeout cleanup includes child processes. +func TestShellTool_Timeout_KillsChildProcesses(t *testing.T) { + tool := NewExecTool("", false) + tool.SetTimeout(1200 * time.Millisecond) + + pidFile := filepath.Join(t.TempDir(), "child.pid") + var cmd string + if runtime.GOOS == "windows" { + escapedPidFile := strings.ReplaceAll(pidFile, "'", "''") + cmd = fmt.Sprintf( + "$p = Start-Process -FilePath powershell -ArgumentList '-NoProfile','-NonInteractive','-Command','Start-Sleep -Seconds 30' -WindowStyle Hidden -PassThru; Set-Content -Path '%s' -Value $p.Id; Start-Sleep -Seconds 30", + escapedPidFile, + ) + } else { + cmd = fmt.Sprintf("sleep 30 & echo $! > %q; sleep 30", pidFile) + } + + result := tool.Execute(context.Background(), map[string]interface{}{"command": cmd}) + if !result.IsError { + t.Fatalf("Expected timeout error for long-running command tree") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "timed out") { + t.Fatalf("Expected timeout message, got: %s", result.ForLLM) + } + + childPID, err := waitForPID(pidFile, 4*time.Second) + if err != nil { + t.Fatalf("failed to obtain child pid file before timeout: %v", err) + } + + // Give timeout cleanup a short grace period before checking process liveness. 
+ time.Sleep(400 * time.Millisecond) + alive := processAlive(childPID) + if alive { + status := processStatus(childPID) + killProcess(childPID) + t.Fatalf("expected child process %d to be terminated after timeout; status: %s", childPID, status) + } +} + +func waitForPID(path string, timeout time.Duration) (int, error) { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + raw, err := os.ReadFile(path) + if err == nil { + pid, parseErr := strconv.Atoi(strings.TrimSpace(string(raw))) + if parseErr == nil && pid > 0 { + return pid, nil + } + } + time.Sleep(50 * time.Millisecond) + } + return 0, fmt.Errorf("pid file %s not ready within %v", path, timeout) +} + +func processAlive(pid int) bool { + if runtime.GOOS == "windows" { + err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", fmt.Sprintf("Get-Process -Id %d | Out-Null", pid)).Run() + return err == nil + } + + output, err := exec.Command("sh", "-c", fmt.Sprintf("ps -o stat= -p %d", pid)).CombinedOutput() + if err != nil { + return false + } + + state := strings.TrimSpace(string(output)) + if state == "" || strings.HasPrefix(state, "Z") { + return false + } + + return true +} + +func processStatus(pid int) string { + if runtime.GOOS == "windows" { + output, _ := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %d", pid)).CombinedOutput() + return strings.TrimSpace(string(output)) + } + + output, _ := exec.Command("sh", "-c", fmt.Sprintf("ps -o pid=,ppid=,pgid=,stat=,cmd= -p %d", pid)).CombinedOutput() + return strings.TrimSpace(string(output)) +} + +func killProcess(pid int) { + if runtime.GOOS == "windows" { + _ = exec.Command("taskkill", "/PID", strconv.Itoa(pid), "/T", "/F").Run() + return + } + _ = exec.Command("kill", "-KILL", strconv.Itoa(pid)).Run() +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 6a6d40ecf5..f805b303c7 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" 
"net/url" "regexp" @@ -339,15 +340,21 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } type WebFetchTool struct { - maxChars int + maxChars int + validator *fetchTargetValidator } func NewWebFetchTool(maxChars int) *WebFetchTool { + return newWebFetchTool(maxChars, nil) +} + +func newWebFetchTool(maxChars int, allowHosts []string) *WebFetchTool { if maxChars <= 0 { maxChars = 50000 } return &WebFetchTool{ - maxChars: maxChars, + maxChars: maxChars, + validator: newFetchTargetValidator(allowHosts, net.DefaultResolver), } } @@ -403,6 +410,10 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } } + if err := t.validator.validateURL(ctx, parsedURL); err != nil { + return ErrorResult(err.Error()) + } + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) @@ -410,6 +421,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) req.Header.Set("User-Agent", userAgent) + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + client := &http.Client{ Timeout: 60 * time.Second, Transport: &http.Transport{ @@ -417,11 +433,15 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) IdleConnTimeout: 30 * time.Second, DisableCompression: false, TLSHandshakeTimeout: 15 * time.Second, + DialContext: guardedDialContext(dialer, t.validator), }, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= 5 { return fmt.Errorf("stopped after 5 redirects") } + if err := t.validator.validateURL(req.Context(), req.URL); err != nil { + return err + } return nil }, } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index a526ea34a0..188540e536 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,10 +5,17 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" 
"testing" ) +// newWebFetchToolForTests centralizes test construction. +// allowHosts is wired in Phase 2 once fetch target policy is introduced. +func newWebFetchToolForTests(maxChars int, allowHosts ...string) *WebFetchTool { + return newWebFetchTool(maxChars, allowHosts) +} + // TestWebTool_WebFetch_Success verifies successful URL fetching func TestWebTool_WebFetch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -18,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -54,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -145,7 +160,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(1000) // Limit to 1000 chars + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(1000, parsed.Host) // Limit to 1000 chars ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -210,7 +229,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() 
args := map[string]interface{}{ "url": server.URL, @@ -254,3 +277,47 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) { t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM) } } + +// TestWebTool_WebFetch_BlocksLoopback verifies loopback targets are blocked +func TestWebTool_WebFetch_BlocksLoopback(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "http://127.0.0.1/", + } + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Fatalf("Expected blocked destination error for loopback target") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "blocked destination") { + t.Fatalf("Expected blocked destination message, got: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_BlocksRedirectToPrivate verifies private redirect hops are blocked +func TestWebTool_WebFetch_BlocksRedirectToPrivate(t *testing.T) { + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://127.0.0.1:1/private", http.StatusFound) + })) + defer redirectServer.Close() + + parsed, err := url.Parse(redirectServer.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + + tool := newWebFetchToolForTests(50000, parsed.Host) + ctx := context.Background() + args := map[string]interface{}{ + "url": redirectServer.URL, + } + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Fatalf("Expected blocked destination error for redirect to private target") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "blocked destination") { + t.Fatalf("Expected blocked destination message, got: %s", result.ForLLM) + } +} diff --git a/pkg/utils/message.go b/pkg/utils/message.go new file mode 100644 index 0000000000..1d05950d9f --- /dev/null +++ b/pkg/utils/message.go @@ -0,0 +1,179 @@ +package utils + +import ( + "strings" +) + +// SplitMessage splits long messages into chunks, preserving 
code block integrity. +// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, +// but may extend to maxLen when needed. +// Call SplitMessage with the full text content and the maximum allowed length of a single message; +// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. +func SplitMessage(content string, maxLen int) []string { + var messages []string + + // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible + codeBlockBuffer := maxLen / 10 + if codeBlockBuffer < 50 { + codeBlockBuffer = 50 + } + if codeBlockBuffer > maxLen/2 { + codeBlockBuffer = maxLen / 2 + } + + for len(content) > 0 { + if len(content) <= maxLen { + messages = append(messages, content) + break + } + + // Effective split point: maxLen minus buffer, to leave room for code blocks + effectiveLimit := maxLen - codeBlockBuffer + if effectiveLimit < maxLen/2 { + effectiveLimit = maxLen / 2 + } + + // Find natural split point within the effective limit + msgEnd := findLastNewline(content[:effectiveLimit], 200) + if msgEnd <= 0 { + msgEnd = findLastSpace(content[:effectiveLimit], 100) + } + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + // Check if this would end with an incomplete code block + candidate := content[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlock(candidate) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend up to maxLen to include the closing ``` + if len(content) > msgEnd { + closingIdx := findNextClosingCodeBlock(content, msgEnd) + if closingIdx > 0 && closingIdx <= maxLen { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Code block is too long to fit in one chunk or missing closing fence. + // Try to split inside by injecting closing and reopening fences. 
+ headerEnd := strings.Index(content[unclosedIdx:], "\n") + if headerEnd == -1 { + headerEnd = unclosedIdx + 3 + } else { + headerEnd += unclosedIdx + } + header := strings.TrimSpace(content[unclosedIdx:headerEnd]) + + // If we have a reasonable amount of content after the header, split inside + if msgEnd > headerEnd+20 { + // Find a better split point closer to maxLen + innerLimit := maxLen - 5 // Leave room for "\n```" + betterEnd := findLastNewline(content[:innerLimit], 200) + if betterEnd > headerEnd { + msgEnd = betterEnd + } else { + msgEnd = innerLimit + } + messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") + content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + continue + } + + // Otherwise, try to split before the code block starts + newEnd := findLastNewline(content[:unclosedIdx], 200) + if newEnd <= 0 { + newEnd = findLastSpace(content[:unclosedIdx], 100) + } + if newEnd > 0 { + msgEnd = newEnd + } else { + // If we can't split before, we MUST split inside (last resort) + if unclosedIdx > 20 { + msgEnd = unclosedIdx + } else { + msgEnd = maxLen - 5 + messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") + content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + continue + } + } + } + } + } + + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + messages = append(messages, content[:msgEnd]) + content = strings.TrimSpace(content[msgEnd:]) + } + + return messages +} + +// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` +// Returns the position of the opening ``` or -1 if all code blocks are complete +func findLastUnclosedCodeBlock(text string) int { + inCodeBlock := false + lastOpenIdx := -1 + + for i := 0; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + // Toggle code block state on each fence + if !inCodeBlock { + // Entering a code block: record this opening fence + lastOpenIdx = i + 
} + inCodeBlock = !inCodeBlock + i += 2 + } + } + + if inCodeBlock { + return lastOpenIdx + } + return -1 +} + +// findNextClosingCodeBlock finds the next closing ``` starting from a position +// Returns the position after the closing ``` or -1 if not found +func findNextClosingCodeBlock(text string, startIdx int) int { + for i := startIdx; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + return i + 3 + } + } + return -1 +} + +// findLastNewline finds the last newline character within the last N characters +// Returns the position of the newline or -1 if not found +func findLastNewline(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == '\n' { + return i + } + } + return -1 +} + +// findLastSpace finds the last space character within the last N characters +// Returns the position of the space or -1 if not found +func findLastSpace(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == ' ' || s[i] == '\t' { + return i + } + } + return -1 +} diff --git a/pkg/utils/message_test.go b/pkg/utils/message_test.go new file mode 100644 index 0000000000..338509437d --- /dev/null +++ b/pkg/utils/message_test.go @@ -0,0 +1,151 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestSplitMessage(t *testing.T) { + longText := strings.Repeat("a", 2500) + longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars + + tests := []struct { + name string + content string + maxLen int + expectChunks int // Check number of chunks + checkContent func(t *testing.T, chunks []string) // Custom validation + }{ + { + name: "Empty message", + content: "", + maxLen: 2000, + expectChunks: 0, + }, + { + name: "Short message fits in one chunk", + 
content: "Hello world", + maxLen: 2000, + expectChunks: 1, + }, + { + name: "Simple split regular text", + content: longText, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) > 2000 { + t.Errorf("Chunk 0 too large: %d", len(chunks[0])) + } + if len(chunks[0])+len(chunks[1]) != len(longText) { + t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) + } + }, + }, + { + name: "Split at newline", + // 1750 chars then newline, then more chars. + // Dynamic buffer: 2000 / 10 = 200. + // Effective limit: 2000 - 200 = 1800. + // Split should happen at newline because it's at 1750 (< 1800). + // Total length must be > 2000 to trigger split. 1750 + 1 + 300 = 2051. + content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300), + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) + } + if chunks[1] != strings.Repeat("b", 300) { + t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) + } + }, + }, + { + name: "Long code block split", + content: "Prefix\n" + longCode, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Check that first chunk ends with closing fence + if !strings.HasSuffix(chunks[0], "\n```") { + t.Error("First chunk should end with injected closing fence") + } + // Check that second chunk starts with the reopened code block header + if !strings.HasPrefix(chunks[1], "```go") { + t.Error("Second chunk should start with injected code block header") + } + }, + }, + { + name: "Preserve Unicode characters", + content: strings.Repeat("\u4e16", 1000), // 3000 bytes + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Just verify we didn't panic and got valid strings. 
+ // Go strings are UTF-8; byte-based slicing here may split mid-rune. + // SplitMessage does not guarantee rune-safe boundaries, so this test only + // asserts the chunk still contains the expected character. + if !strings.Contains(chunks[0], "\u4e16") { + t.Error("Chunk should contain unicode characters") + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SplitMessage(tc.content, tc.maxLen) + + if tc.expectChunks == 0 { + if len(got) != 0 { + t.Errorf("Expected 0 chunks, got %d", len(got)) + } + return + } + + if len(got) != tc.expectChunks { + t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got)) + // Log sizes for debugging + for i, c := range got { + t.Logf("Chunk %d length: %d", i, len(c)) + } + return // Stop further checks if count assumes specific split + } + + if tc.checkContent != nil { + tc.checkContent(t, got) + } + }) + } +} + +func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { + // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting + + // 60 chars total approximately + content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```" + maxLen := 40 + + chunks := SplitMessage(content, maxLen) + + if len(chunks) != 2 { + t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks) + } + + // First chunk must end with "\n```" + if !strings.HasSuffix(chunks[0], "\n```") { + t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0]) + } + + // Second chunk must start with the header "```go" + if !strings.HasPrefix(chunks[1], "```go") { + t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1]) + } + + // First chunk must respect maxLen + if len(chunks[0]) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0])) + } +}