From 3b4b6bfc026c8fba8c62ef2293c9e751152d2c26 Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Sun, 15 Feb 2026 17:52:02 -0500 Subject: [PATCH 1/3] fix(telegram): split oversized outbound messages into chunks --- pkg/channels/telegram.go | 179 ++++++++++++++++++++++++++++++++-- pkg/channels/telegram_test.go | 54 ++++++++++ 2 files changed, 224 insertions(+), 9 deletions(-) create mode 100644 pkg/channels/telegram_test.go diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index b14b1632e4..e1b25f0383 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -35,6 +35,11 @@ type thinkingCancel struct { fn context.CancelFunc } +const ( + telegramMaxMessageLength = 4096 + telegramSplitTarget = 3900 +) + func (c *thinkingCancel) Cancel() { if c != nil && c.fn != nil { c.fn() @@ -137,35 +142,66 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err c.stopThinking.Delete(msg.ChatID) } - htmlContent := markdownToTelegramHTML(msg.Content) + chunks := splitTelegramMessageContent(msg.Content, telegramMaxMessageLength) + if len(chunks) == 0 { + return nil + } // Try to edit placeholder if pID, ok := c.placeholders.Load(msg.ChatID); ok { c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { + if err := c.editMessageChunk(ctx, chatID, pID.(int), chunks[0]); err == nil { + for i := 1; i < len(chunks); i++ { + if sendErr := c.sendMessageChunk(ctx, chatID, chunks[i]); sendErr != nil { + return sendErr + } + } return nil } // Fallback to new message if edit fails } + for _, chunk := range chunks { + if err := c.sendMessageChunk(ctx, chatID, chunk); err != nil { + return err + } + } + + return nil +} + +func (c *TelegramChannel) sendMessageChunk(ctx context.Context, chatID int64, content string) error { + htmlContent := markdownToTelegramHTML(content) tgMsg := 
tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML - if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil { logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ "error": err.Error(), }) - tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) - return err + plainMsg := tu.Message(tu.ID(chatID), content) + _, fallbackErr := c.bot.SendMessage(ctx, plainMsg) + return fallbackErr } return nil } +func (c *TelegramChannel) editMessageChunk(ctx context.Context, chatID int64, messageID int, content string) error { + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(chatID), messageID, htmlContent) + editMsg.ParseMode = telego.ModeHTML + if _, err := c.bot.EditMessageText(ctx, editMsg); err != nil { + logger.ErrorCF("telegram", "HTML edit parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) + plainEdit := tu.EditMessageText(tu.ID(chatID), messageID, content) + _, fallbackErr := c.bot.EditMessageText(ctx, plainEdit) + return fallbackErr + } + return nil +} + func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { message := update.Message if message == nil { @@ -435,6 +471,131 @@ func markdownToTelegramHTML(text string) string { return text } +func splitTelegramMessageContent(text string, maxLen int) []string { + text = strings.TrimSpace(text) + if text == "" { + return nil + } + + if maxLen <= 0 { + return []string{text} + } + + if runeLen(markdownToTelegramHTML(text)) <= maxLen { + return []string{text} + } + + target := telegramSplitTarget + if target >= maxLen { + target = maxLen - 64 + } + if target < 256 { + target = maxLen / 2 + } + if target < 1 { + target = 1 + } + + parts := splitTextByBoundary(text, target) + if len(parts) == 1 { + runes := []rune(text) + if len(runes) <= 1 { + return parts + } + mid := len(runes) / 2 + if 
mid < 1 { + mid = 1 + } + parts = []string{ + strings.TrimSpace(string(runes[:mid])), + strings.TrimSpace(string(runes[mid:])), + } + } + + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + out = append(out, splitTelegramMessageContent(p, maxLen)...) + } + return out +} + +func splitTextByBoundary(text string, limit int) []string { + runes := []rune(text) + if len(runes) <= limit { + return []string{text} + } + + result := make([]string, 0) + for len(runes) > 0 { + if len(runes) <= limit { + tail := strings.TrimSpace(string(runes)) + if tail != "" { + result = append(result, tail) + } + break + } + + splitAt := findSplitPoint(runes, limit) + if splitAt <= 0 { + splitAt = limit + } + if splitAt > len(runes) { + splitAt = len(runes) + } + + chunk := strings.TrimSpace(string(runes[:splitAt])) + if chunk != "" { + result = append(result, chunk) + } + + runes = runes[splitAt:] + for len(runes) > 0 && (runes[0] == '\n' || runes[0] == '\r' || runes[0] == ' ' || runes[0] == '\t') { + runes = runes[1:] + } + } + + return result +} + +func findSplitPoint(runes []rune, limit int) int { + if len(runes) <= limit { + return len(runes) + } + if limit <= 1 { + return 1 + } + + floor := limit / 2 + if floor < 1 { + floor = 1 + } + + for i := limit; i > floor; i-- { + if i > 1 && runes[i-1] == '\n' && runes[i-2] == '\n' { + return i + } + } + for i := limit; i > floor; i-- { + if runes[i-1] == '\n' { + return i + } + } + for i := limit; i > floor; i-- { + if runes[i-1] == ' ' || runes[i-1] == '\t' { + return i + } + } + return limit +} + +func runeLen(text string) int { + return len([]rune(text)) +} + type codeBlockMatch struct { text string codes []string diff --git a/pkg/channels/telegram_test.go b/pkg/channels/telegram_test.go new file mode 100644 index 0000000000..fbf77d6e74 --- /dev/null +++ b/pkg/channels/telegram_test.go @@ -0,0 +1,54 @@ +package channels + +import ( + "strings" + "testing" +) + +func 
TestSplitTelegramMessageContentShortMessage(t *testing.T) { + input := "hello world" + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) != 1 { + t.Fatalf("len(chunks) = %d, want 1", len(chunks)) + } + if chunks[0] != input { + t.Fatalf("chunk[0] = %q, want %q", chunks[0], input) + } +} + +func TestSplitTelegramMessageContentLongMessage(t *testing.T) { + input := strings.Repeat("This is a long telegram message chunk. ", 300) + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + + for i, chunk := range chunks { + if strings.TrimSpace(chunk) == "" { + t.Fatalf("chunk %d is empty", i) + } + html := markdownToTelegramHTML(chunk) + if runeLen(html) > telegramMaxMessageLength { + t.Fatalf("chunk %d HTML length = %d, want <= %d", i, runeLen(html), telegramMaxMessageLength) + } + } +} + +func TestSplitTelegramMessageContentEscapingExpansion(t *testing.T) { + // '&' expands to '&' in HTML, so this validates recursive splitting safety. 
+ input := strings.Repeat("&", 5000) + chunks := splitTelegramMessageContent(input, telegramMaxMessageLength) + + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + + for i, chunk := range chunks { + html := markdownToTelegramHTML(chunk) + if runeLen(html) > telegramMaxMessageLength { + t.Fatalf("chunk %d escaped HTML length = %d, want <= %d", i, runeLen(html), telegramMaxMessageLength) + } + } +} From ddc942542e8cf84e04c422cd5b7ced3b31a2bba9 Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Mon, 16 Feb 2026 13:06:31 -0500 Subject: [PATCH 2/3] harden tool security boundaries across web, filesystem, and exec Phase 1 - Security regression baseline - add blocking tests for loopback and redirect-to-private web_fetch targets - add filesystem symlink escape and prefix-bypass restriction tests - add shell timeout test to verify child-process cleanup behavior Phase 2 - Web fetch egress guard - introduce stdlib-only fetch target validator for host/IP policy checks - enforce target validation before connect, during redirect, and in dial path - keep public constructor API and add internal test-aware constructor path Phase 3 - Canonical filesystem boundary enforcement - replace prefix checks with canonical workspace containment via filepath.Rel - canonicalize workspace and target paths with symlink-aware resolution - protect create/read/write/list/edit/append flows through shared validation Phase 4 - Shell timeout process-tree cleanup - switch command execution to explicit start/wait with timeout select - add OS-specific process tree helpers for unix process groups and taskkill on windows - preserve existing output contract and timeout messaging semantics Phase 5 - Documentation and contributor guidance - document web_fetch network boundary and complementary security model - add tool security checklist for future built-in tool additions Verification - go test ./pkg/tools -run TestWebTool_WebFetch_Blocks (via golang:1.25) - go test 
./pkg/tools -run TestFilesystemTool_Restrict (via golang:1.25) - go test ./pkg/tools -run TestShellTool_Timeout_KillsChildProcesses (via golang:1.25) - go test ./pkg/tools -run TestWebTool_WebFetch_ (via golang:1.25) - go test ./pkg/tools -run TestFilesystemTool_ (via golang:1.25) - go test ./pkg/tools -run TestShellTool_ (via golang:1.25) - go test ./pkg/tools (via golang:1.25) - go generate ./... && go test ./... (via golang:1.25) --- README.md | 27 ++++- pkg/tools/README.md | 9 ++ pkg/tools/filesystem.go | 95 +++++++++++++-- pkg/tools/filesystem_test.go | 52 +++++++++ pkg/tools/network_guard.go | 178 +++++++++++++++++++++++++++++ pkg/tools/shell.go | 51 +++++++-- pkg/tools/shell_process_unix.go | 37 ++++++ pkg/tools/shell_process_windows.go | 30 +++++ pkg/tools/shell_test.go | 96 ++++++++++++++++ pkg/tools/web.go | 24 +++- pkg/tools/web_test.go | 75 +++++++++++- 11 files changed, 647 insertions(+), 27 deletions(-) create mode 100644 pkg/tools/README.md create mode 100644 pkg/tools/network_guard.go create mode 100644 pkg/tools/shell_process_unix.go create mode 100644 pkg/tools/shell_process_windows.go diff --git a/README.md b/README.md index 091af28116..99e3ec77df 100644 --- a/README.md +++ b/README.md @@ -511,6 +511,21 @@ When `restrict_to_workspace: true`, the following tools are sandboxed: | `append_file` | Append to files | Only files within workspace | | `exec` | Execute commands | Command paths must be within workspace | +#### Web Fetch Network Boundary + +`web_fetch` enforces an outbound network boundary independent of `restrict_to_workspace`. + +Blocked destination classes include: + +* loopback +* private RFC1918 / unique-local ranges +* link-local +* multicast +* unspecified / non-routable internal targets +* redirect hops that resolve to blocked targets + +This policy is applied both before connect and during redirect handling, so a public URL cannot bounce into private infrastructure through redirects. 
+ #### Additional Exec Protection Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands: @@ -534,6 +549,11 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous {tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} ``` +``` +[ERROR] tool: Tool execution failed +{tool=web_fetch, error=blocked destination: host "127.0.0.1" resolves to non-public IP 127.0.0.1} +``` + #### Disabling Restrictions (Security Risk) If you need the agent to access paths outside the workspace: @@ -560,7 +580,12 @@ export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false #### Security Boundary Consistency -The `restrict_to_workspace` setting applies consistently across all execution paths: +PicoClaw enforces complementary boundaries: + +* Filesystem + shell boundary via `restrict_to_workspace` +* Network egress boundary via `web_fetch` public-target validation + +The workspace boundary (`restrict_to_workspace`) applies consistently across all execution paths: | Execution Path | Security Boundary | |----------------|-------------------| diff --git a/pkg/tools/README.md b/pkg/tools/README.md new file mode 100644 index 0000000000..83143c2115 --- /dev/null +++ b/pkg/tools/README.md @@ -0,0 +1,9 @@ +# Tool Security Checklist + +When adding a new built-in tool, include these minimum safety checks: + +1. Path boundary: if the tool reads/writes files or executes commands with paths, enforce canonical workspace membership when `restrict_to_workspace=true`. +2. Network boundary: if the tool performs outbound network calls, reject loopback/private/link-local/multicast/unspecified/internal targets and validate redirect hops. +3. Timeout behavior: long-running operations must use deterministic timeout/cancel handling and terminate child processes where process trees are possible. +4. 
Regression tests: add explicit tests for blocked behavior (not just happy-path errors), including redirect/path traversal/process-leak scenarios where relevant. +5. Error clarity: return explicit denial reasons (`blocked destination`, `outside workspace`, `timed out`) so behavior is auditable in logs. diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 2376877344..e97a9b6f84 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -19,21 +19,98 @@ func validatePath(path, workspace string, restrict bool) (string, error) { return "", fmt.Errorf("failed to resolve workspace path: %w", err) } - var absPath string + absPath, err := resolveTargetPath(path, absWorkspace) + if err != nil { + return "", err + } + + if !restrict { + return absPath, nil + } + + canonicalPath, err := pathWithinWorkspace(absPath, absWorkspace) + if err != nil { + return "", err + } + + return canonicalPath, nil +} + +func resolveTargetPath(path, absWorkspace string) (string, error) { if filepath.IsAbs(path) { - absPath = filepath.Clean(path) - } else { - absPath, err = filepath.Abs(filepath.Join(absWorkspace, path)) - if err != nil { - return "", fmt.Errorf("failed to resolve file path: %w", err) - } + return filepath.Clean(path), nil + } + + absPath, err := filepath.Abs(filepath.Join(absWorkspace, path)) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + return filepath.Clean(absPath), nil +} + +func pathWithinWorkspace(target, workspace string) (string, error) { + canonicalWorkspace, err := canonicalizeExistingPath(workspace) + if err != nil { + return "", fmt.Errorf("failed to canonicalize workspace path: %w", err) + } + + canonicalTarget, err := canonicalizePathForBoundary(target) + if err != nil { + return "", fmt.Errorf("failed to canonicalize target path: %w", err) } - if restrict && !strings.HasPrefix(absPath, absWorkspace) { + rel, err := filepath.Rel(canonicalWorkspace, canonicalTarget) + if err != nil { + return "", 
fmt.Errorf("failed to evaluate workspace boundary: %w", err) + } + + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { return "", fmt.Errorf("access denied: path is outside the workspace") } - return absPath, nil + return canonicalTarget, nil +} + +func canonicalizeExistingPath(path string) (string, error) { + resolved, err := filepath.EvalSymlinks(path) + if err != nil { + return "", err + } + return filepath.Clean(resolved), nil +} + +func canonicalizePathForBoundary(path string) (string, error) { + cleanPath := filepath.Clean(path) + segments := make([]string, 0, 4) + current := cleanPath + + for { + _, err := os.Lstat(current) + if err == nil { + resolved, evalErr := filepath.EvalSymlinks(current) + if evalErr != nil { + return "", evalErr + } + + for i := len(segments) - 1; i >= 0; i-- { + resolved = filepath.Join(resolved, segments[i]) + } + + return filepath.Clean(resolved), nil + } + + if !os.IsNotExist(err) { + return "", err + } + + parent := filepath.Dir(current) + if parent == current { + return "", fmt.Errorf("could not resolve existing parent for %q", path) + } + + segments = append(segments, filepath.Base(current)) + current = parent + } } type ReadFileTool struct { diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 2707f29b5b..f245e4d56c 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -247,3 +247,55 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) } } + +// TestFilesystemTool_Restrict_BlocksSymlinkEscape verifies symlink traversal outside workspace is blocked. 
+func TestFilesystemTool_Restrict_BlocksSymlinkEscape(t *testing.T) { + workspace := t.TempDir() + outside := t.TempDir() + secretFile := filepath.Join(outside, "secret.txt") + if err := os.WriteFile(secretFile, []byte("do-not-read"), 0644); err != nil { + t.Fatalf("failed to create outside file: %v", err) + } + + linkPath := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secretFile, linkPath); err != nil { + t.Skipf("symlink unavailable in this environment: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]interface{}{"path": "leak.txt"}) + if !result.IsError { + t.Fatalf("Expected symlink escape to be blocked") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "outside") { + t.Fatalf("Expected workspace boundary error, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_Restrict_BlocksPrefixBypass verifies prefix confusion paths are blocked. +func TestFilesystemTool_Restrict_BlocksPrefixBypass(t *testing.T) { + parent := t.TempDir() + workspace := filepath.Join(parent, "workspace") + prefixBypassDir := filepath.Join(parent, "workspace-evil") + + if err := os.MkdirAll(workspace, 0755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + if err := os.MkdirAll(prefixBypassDir, 0755); err != nil { + t.Fatalf("failed to create bypass dir: %v", err) + } + + bypassFile := filepath.Join(prefixBypassDir, "stolen.txt") + if err := os.WriteFile(bypassFile, []byte("secret"), 0644); err != nil { + t.Fatalf("failed to create bypass file: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]interface{}{"path": bypassFile}) + if !result.IsError { + t.Fatalf("Expected prefix bypass path to be blocked") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "outside") { + t.Fatalf("Expected workspace boundary error, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/network_guard.go 
b/pkg/tools/network_guard.go new file mode 100644 index 0000000000..17eaa3bf21 --- /dev/null +++ b/pkg/tools/network_guard.go @@ -0,0 +1,178 @@ +package tools + +import ( + "context" + "fmt" + "net" + "net/netip" + "net/url" + "strings" +) + +var ( + cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10") + benchmarkPrefix = netip.MustParsePrefix("198.18.0.0/15") + reservedPrefix = netip.MustParsePrefix("240.0.0.0/4") +) + +type fetchTargetValidator struct { + resolver *net.Resolver + allowedHosts map[string]struct{} +} + +func newFetchTargetValidator(allowHosts []string, resolver *net.Resolver) *fetchTargetValidator { + allowed := make(map[string]struct{}, len(allowHosts)) + for _, host := range allowHosts { + normalized := normalizeHostToken(host) + if normalized != "" { + allowed[normalized] = struct{}{} + } + } + + if resolver == nil { + resolver = net.DefaultResolver + } + + return &fetchTargetValidator{ + resolver: resolver, + allowedHosts: allowed, + } +} + +func (v *fetchTargetValidator) validateURL(ctx context.Context, target *url.URL) error { + host := normalizeHostToken(target.Hostname()) + if host == "" { + return fmt.Errorf("missing domain in URL") + } + + port := target.Port() + if v.isAllowed(host, port) { + return nil + } + + if isBlockedHostname(host) { + return fmt.Errorf("blocked destination: host %q is internal-only", target.Hostname()) + } + + if ip, ok := parseIPLiteral(host); ok { + if IsBlockedIP(ip) { + return fmt.Errorf("blocked destination: IP %s is not publicly routable", ip) + } + return nil + } + + addrs, err := v.resolver.LookupNetIP(ctx, "ip", host) + if err != nil { + return fmt.Errorf("failed to resolve host %q: %w", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("failed to resolve host %q: no records", host) + } + + for _, addr := range addrs { + if IsBlockedIP(addr) { + return fmt.Errorf("blocked destination: host %q resolves to non-public IP %s", host, addr) + } + } + + return nil +} + +func (v *fetchTargetValidator) 
isAllowed(host, port string) bool { + if len(v.allowedHosts) == 0 { + return false + } + + if _, ok := v.allowedHosts[host]; ok { + return true + } + if port != "" { + if _, ok := v.allowedHosts[host+":"+port]; ok { + return true + } + } + return false +} + +// ValidateFetchTarget applies the default web fetch target policy. +func ValidateFetchTarget(target *url.URL) error { + return newFetchTargetValidator(nil, net.DefaultResolver).validateURL(context.Background(), target) +} + +// IsBlockedIP returns true for IPs that should not be reachable from web_fetch. +func IsBlockedIP(addr netip.Addr) bool { + if !addr.IsValid() { + return true + } + + if addr.IsLoopback() || + addr.IsPrivate() || + addr.IsLinkLocalUnicast() || + addr.IsLinkLocalMulticast() || + addr.IsMulticast() || + addr.IsUnspecified() || + addr.IsInterfaceLocalMulticast() { + return true + } + + if addr.Is4() { + if cgnatPrefix.Contains(addr) || benchmarkPrefix.Contains(addr) || reservedPrefix.Contains(addr) { + return true + } + } + + return false +} + +func isBlockedHostname(host string) bool { + if host == "localhost" || strings.HasSuffix(host, ".localhost") { + return true + } + if strings.HasSuffix(host, ".local") || strings.HasSuffix(host, ".internal") { + return true + } + if host == "metadata.google.internal" || host == "metadata" { + return true + } + return false +} + +func parseIPLiteral(host string) (netip.Addr, bool) { + if i := strings.Index(host, "%"); i >= 0 { + host = host[:i] + } + addr, err := netip.ParseAddr(host) + if err != nil { + return netip.Addr{}, false + } + return addr.Unmap(), true +} + +func normalizeHostToken(raw string) string { + return strings.TrimSuffix(strings.ToLower(strings.TrimSpace(raw)), ".") +} + +func guardedDialContext(base *net.Dialer, validator *fetchTargetValidator) func(ctx context.Context, network, address string) (net.Conn, error) { + if base == nil { + base = &net.Dialer{} + } + + return func(ctx context.Context, network, address string) (net.Conn, 
error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + host = address + port = "" + } + + target := &url.URL{Host: host} + if port != "" { + target.Host = net.JoinHostPort(host, port) + } + + if err := validator.validateURL(ctx, target); err != nil { + return nil, err + } + + return base.DialContext(ctx, network, address) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 1ca3fc35a9..9d7cd603ec 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -94,33 +94,62 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To var cmd *exec.Cmd if runtime.GOOS == "windows" { - cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command) + cmd = exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", command) } else { - cmd = exec.CommandContext(cmdCtx, "sh", "-c", command) + cmd = exec.Command("sh", "-c", command) } if cwd != "" { cmd.Dir = cwd } + prepareCommandForTreeControl(cmd) + var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr - err := cmd.Run() + if err := cmd.Start(); err != nil { + return ErrorResult(fmt.Sprintf("failed to start command: %v", err)) + } + + waitDone := make(chan error, 1) + go func() { + waitDone <- cmd.Wait() + }() + + var err error + timedOut := false + + select { + case err = <-waitDone: + case <-cmdCtx.Done(): + if killErr := killCommandTree(cmd); killErr != nil { + return ErrorResult(fmt.Sprintf("failed to terminate command tree: %v", killErr)) + } + + timedOut = cmdCtx.Err() == context.DeadlineExceeded + select { + case err = <-waitDone: + case <-time.After(3 * time.Second): + return ErrorResult("command termination timed out after deadline") + } + } + output := stdout.String() if stderr.Len() > 0 { output += "\nSTDERR:\n" + stderr.String() } - if err != nil { - if cmdCtx.Err() == context.DeadlineExceeded { - msg := fmt.Sprintf("Command timed out after %v", t.timeout) - return &ToolResult{ - 
ForLLM: msg, - ForUser: msg, - IsError: true, - } + if timedOut { + msg := fmt.Sprintf("Command timed out after %v", t.timeout) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + IsError: true, } + } + + if err != nil { output += fmt.Sprintf("\nExit code: %v", err) } diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/shell_process_unix.go new file mode 100644 index 0000000000..6037aa7489 --- /dev/null +++ b/pkg/tools/shell_process_unix.go @@ -0,0 +1,37 @@ +//go:build !windows + +package tools + +import ( + "fmt" + "os" + "os/exec" + "syscall" +) + +func prepareCommandForTreeControl(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} + +func killCommandTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err == nil { + if killErr := syscall.Kill(-pgid, syscall.SIGKILL); killErr == nil || killErr == syscall.ESRCH { + return nil + } else { + return fmt.Errorf("failed to kill process group %d: %w", pgid, killErr) + } + } + + if killErr := cmd.Process.Kill(); killErr != nil && killErr != os.ErrProcessDone { + return fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, killErr) + } + + return nil +} diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/shell_process_windows.go new file mode 100644 index 0000000000..6669b3dc87 --- /dev/null +++ b/pkg/tools/shell_process_windows.go @@ -0,0 +1,30 @@ +//go:build windows + +package tools + +import ( + "fmt" + "os/exec" + "strconv" + "syscall" +) + +func prepareCommandForTreeControl(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + +func killCommandTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pid := strconv.Itoa(cmd.Process.Pid) + killCmd := exec.Command("taskkill", "/T", "/F", "/PID", pid) + if err := killCmd.Run(); err != nil { + return fmt.Errorf("taskkill failed for pid 
%s: %w", pid, err) + } + + return nil +} diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index c06468a39a..603072dcde 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -2,8 +2,12 @@ package tools import ( "context" + "fmt" "os" + "os/exec" "path/filepath" + "runtime" + "strconv" "strings" "testing" "time" @@ -208,3 +212,95 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) { t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) } } + +// TestShellTool_Timeout_KillsChildProcesses verifies timeout cleanup includes child processes. +func TestShellTool_Timeout_KillsChildProcesses(t *testing.T) { + tool := NewExecTool("", false) + tool.SetTimeout(1200 * time.Millisecond) + + pidFile := filepath.Join(t.TempDir(), "child.pid") + var cmd string + if runtime.GOOS == "windows" { + escapedPidFile := strings.ReplaceAll(pidFile, "'", "''") + cmd = fmt.Sprintf( + "$p = Start-Process -FilePath powershell -ArgumentList '-NoProfile','-NonInteractive','-Command','Start-Sleep -Seconds 30' -WindowStyle Hidden -PassThru; Set-Content -Path '%s' -Value $p.Id; Start-Sleep -Seconds 30", + escapedPidFile, + ) + } else { + cmd = fmt.Sprintf("sleep 30 & echo $! > %q; sleep 30", pidFile) + } + + result := tool.Execute(context.Background(), map[string]interface{}{"command": cmd}) + if !result.IsError { + t.Fatalf("Expected timeout error for long-running command tree") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "timed out") { + t.Fatalf("Expected timeout message, got: %s", result.ForLLM) + } + + childPID, err := waitForPID(pidFile, 4*time.Second) + if err != nil { + t.Fatalf("failed to obtain child pid file before timeout: %v", err) + } + + // Give timeout cleanup a short grace period before checking process liveness. 
+ time.Sleep(400 * time.Millisecond) + alive := processAlive(childPID) + if alive { + status := processStatus(childPID) + killProcess(childPID) + t.Fatalf("expected child process %d to be terminated after timeout; status: %s", childPID, status) + } +} + +func waitForPID(path string, timeout time.Duration) (int, error) { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + raw, err := os.ReadFile(path) + if err == nil { + pid, parseErr := strconv.Atoi(strings.TrimSpace(string(raw))) + if parseErr == nil && pid > 0 { + return pid, nil + } + } + time.Sleep(50 * time.Millisecond) + } + return 0, fmt.Errorf("pid file %s not ready within %v", path, timeout) +} + +func processAlive(pid int) bool { + if runtime.GOOS == "windows" { + err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", fmt.Sprintf("Get-Process -Id %d | Out-Null", pid)).Run() + return err == nil + } + + output, err := exec.Command("sh", "-c", fmt.Sprintf("ps -o stat= -p %d", pid)).CombinedOutput() + if err != nil { + return false + } + + state := strings.TrimSpace(string(output)) + if state == "" || strings.HasPrefix(state, "Z") { + return false + } + + return true +} + +func processStatus(pid int) string { + if runtime.GOOS == "windows" { + output, _ := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %d", pid)).CombinedOutput() + return strings.TrimSpace(string(output)) + } + + output, _ := exec.Command("sh", "-c", fmt.Sprintf("ps -o pid=,ppid=,pgid=,stat=,cmd= -p %d", pid)).CombinedOutput() + return strings.TrimSpace(string(output)) +} + +func killProcess(pid int) { + if runtime.GOOS == "windows" { + _ = exec.Command("taskkill", "/PID", strconv.Itoa(pid), "/T", "/F").Run() + return + } + _ = exec.Command("kill", "-KILL", strconv.Itoa(pid)).Run() +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index ccd9958429..5ba87e2e19 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" 
"net/url" "regexp" @@ -266,15 +267,21 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } type WebFetchTool struct { - maxChars int + maxChars int + validator *fetchTargetValidator } func NewWebFetchTool(maxChars int) *WebFetchTool { + return newWebFetchTool(maxChars, nil) +} + +func newWebFetchTool(maxChars int, allowHosts []string) *WebFetchTool { if maxChars <= 0 { maxChars = 50000 } return &WebFetchTool{ - maxChars: maxChars, + maxChars: maxChars, + validator: newFetchTargetValidator(allowHosts, net.DefaultResolver), } } @@ -330,6 +337,10 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } } + if err := t.validator.validateURL(ctx, parsedURL); err != nil { + return ErrorResult(err.Error()) + } + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) @@ -337,6 +348,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) req.Header.Set("User-Agent", userAgent) + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + client := &http.Client{ Timeout: 60 * time.Second, Transport: &http.Transport{ @@ -344,11 +360,15 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) IdleConnTimeout: 30 * time.Second, DisableCompression: false, TLSHandshakeTimeout: 15 * time.Second, + DialContext: guardedDialContext(dialer, t.validator), }, CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= 5 { return fmt.Errorf("stopped after 5 redirects") } + if err := t.validator.validateURL(req.Context(), req.URL); err != nil { + return err + } return nil }, } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 988eada162..40f8dbdf7b 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,10 +5,17 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" 
"testing" ) +// newWebFetchToolForTests centralizes test construction. +// allowHosts is wired in Phase 2 once fetch target policy is introduced. +func newWebFetchToolForTests(maxChars int, allowHosts ...string) *WebFetchTool { + return newWebFetchTool(maxChars, allowHosts) +} + // TestWebTool_WebFetch_Success verifies successful URL fetching func TestWebTool_WebFetch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -18,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -54,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -145,7 +160,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(1000) // Limit to 1000 chars + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(1000, parsed.Host) // Limit to 1000 chars ctx := context.Background() args := map[string]interface{}{ "url": server.URL, @@ -206,7 +225,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { })) defer server.Close() - tool := NewWebFetchTool(50000) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + tool := newWebFetchToolForTests(50000, parsed.Host) ctx := context.Background() 
args := map[string]interface{}{ "url": server.URL, @@ -250,3 +273,47 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) { t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM) } } + +// TestWebTool_WebFetch_BlocksLoopback verifies loopback targets are blocked +func TestWebTool_WebFetch_BlocksLoopback(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "http://127.0.0.1/", + } + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Fatalf("Expected blocked destination error for loopback target") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "blocked destination") { + t.Fatalf("Expected blocked destination message, got: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_BlocksRedirectToPrivate verifies private redirect hops are blocked +func TestWebTool_WebFetch_BlocksRedirectToPrivate(t *testing.T) { + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://127.0.0.1:1/private", http.StatusFound) + })) + defer redirectServer.Close() + + parsed, err := url.Parse(redirectServer.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + + tool := newWebFetchToolForTests(50000, parsed.Host) + ctx := context.Background() + args := map[string]interface{}{ + "url": redirectServer.URL, + } + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Fatalf("Expected blocked destination error for redirect to private target") + } + if !strings.Contains(strings.ToLower(result.ForLLM), "blocked destination") { + t.Fatalf("Expected blocked destination message, got: %s", result.ForLLM) + } +} From a6c051e0d92dc091cf03bb91333955fec74838ba Mon Sep 17 00:00:00 2001 From: Jared Mahotiere Date: Thu, 19 Feb 2026 10:59:51 -0500 Subject: [PATCH 3/3] refactor(telegram): use shared SplitMessage utility --- pkg/channels/telegram.go | 129 +++++++-------------------- 
// splitTelegramMessageContent splits outbound message text into chunks whose
// HTML rendering fits within Telegram's per-message limit. Splitting is
// delegated to utils.SplitMessage; the result is then re-checked against the
// post-HTML-conversion length, since markdownToTelegramHTML can expand text.
func splitTelegramMessageContent(text string, maxLen int) []string {
	text = strings.TrimSpace(text)
	if text == "" {
		return nil
	}

	// A non-positive limit disables splitting entirely.
	if maxLen <= 0 {
		return []string{text}
	}

	chunks := utils.SplitMessage(text, maxLen)
	return enforceTelegramMessageHTMLLimit(chunks, maxLen)
}

// enforceTelegramMessageHTMLLimit re-splits any chunk whose HTML-converted
// rune length still exceeds maxLen. Oversized chunks are halved (by rune
// count, capped at maxLen) via utils.SplitMessage and the halves re-checked
// recursively; a manual midpoint split guarantees progress when
// utils.SplitMessage returns a single chunk, so recursion always terminates
// once chunks shrink to one rune.
func enforceTelegramMessageHTMLLimit(chunks []string, maxLen int) []string {
	out := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		chunk = strings.TrimSpace(chunk)
		if chunk == "" {
			continue
		}

		// Fast path: the rendered chunk already fits.
		if runeLen(markdownToTelegramHTML(chunk)) <= maxLen {
			out = append(out, chunk)
			continue
		}

		// A single rune cannot be split further; pass it through as-is.
		runes := []rune(chunk)
		if len(runes) <= 1 {
			out = append(out, chunk)
			continue
		}

		// Halve by rune count so HTML expansion has headroom on each side.
		splitLimit := len(runes) / 2
		if splitLimit > maxLen {
			splitLimit = maxLen
		}
		if splitLimit < 1 {
			splitLimit = 1
		}

		subChunks := utils.SplitMessage(chunk, splitLimit)
		if len(subChunks) <= 1 {
			// utils.SplitMessage could not split (e.g. no byte-level reduction);
			// force a midpoint split to guarantee the recursion makes progress.
			mid := len(runes) / 2
			if mid < 1 {
				mid = 1
			}
			subChunks = []string{
				string(runes[:mid]),
				string(runes[mid:]),
			}
		}
		out = append(out, enforceTelegramMessageHTMLLimit(subChunks, maxLen)...)
	}
	return out
}
// SplitMessage splits long messages into chunks of at most maxLen bytes,
// preserving fenced code-block integrity: a chunk that would end inside a
// ``` block either extends to the closing fence or gets a closing ```
// injected, with the next chunk reopening the block using the original
// header (e.g. "```go").
//
// A dynamic buffer (10% of maxLen, raised to at least 50 but capped at
// maxLen/2) is reserved below maxLen so there is room to extend a chunk to a
// closing fence. Raw byte-offset splits are nudged back to a UTF-8 rune
// boundary so multi-byte characters are never cut in half.
func SplitMessage(content string, maxLen int) []string {
	var messages []string

	// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible.
	codeBlockBuffer := maxLen / 10
	if codeBlockBuffer < 50 {
		codeBlockBuffer = 50
	}
	if codeBlockBuffer > maxLen/2 {
		codeBlockBuffer = maxLen / 2
	}

	for len(content) > 0 {
		if len(content) <= maxLen {
			messages = append(messages, content)
			break
		}

		// Effective split point: maxLen minus buffer, leaving room to absorb
		// an injected closing fence or a nearby real one.
		effectiveLimit := maxLen - codeBlockBuffer
		if effectiveLimit < maxLen/2 {
			effectiveLimit = maxLen / 2
		}

		// Prefer a natural split point (newline, then space) near the limit.
		msgEnd := findLastNewline(content[:effectiveLimit], 200)
		if msgEnd <= 0 {
			msgEnd = findLastSpace(content[:effectiveLimit], 100)
		}
		if msgEnd <= 0 {
			// No natural boundary: split at the limit, but never mid-rune.
			msgEnd = alignToRuneStart(content, effectiveLimit)
		}

		// Check whether this split would leave a code block unterminated.
		candidate := content[:msgEnd]
		unclosedIdx := findLastUnclosedCodeBlock(candidate)

		if unclosedIdx >= 0 && len(content) > msgEnd {
			// Try to extend up to maxLen to include the closing ```.
			closingIdx := findNextClosingCodeBlock(content, msgEnd)
			if closingIdx > 0 && closingIdx <= maxLen {
				msgEnd = closingIdx
			} else {
				// Code block is too long to fit in one chunk or has no closing
				// fence. Recover the fence header ("```lang") so the block can
				// be re-opened in the next chunk.
				headerEnd := strings.Index(content[unclosedIdx:], "\n")
				if headerEnd == -1 {
					headerEnd = unclosedIdx + 3
				} else {
					headerEnd += unclosedIdx
				}
				header := strings.TrimSpace(content[unclosedIdx:headerEnd])

				// With enough content after the header, split inside the block
				// and inject a closing/reopening fence pair.
				if msgEnd > headerEnd+20 {
					innerLimit := alignToRuneStart(content, maxLen-5) // room for "\n```"
					betterEnd := findLastNewline(content[:innerLimit], 200)
					if betterEnd > headerEnd {
						msgEnd = betterEnd
					} else {
						msgEnd = innerLimit
					}
					messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
					content = strings.TrimSpace(header + "\n" + content[msgEnd:])
					continue
				}

				// Otherwise try to split just before the code block starts.
				newEnd := findLastNewline(content[:unclosedIdx], 200)
				if newEnd <= 0 {
					newEnd = findLastSpace(content[:unclosedIdx], 100)
				}
				if newEnd > 0 {
					msgEnd = newEnd
				} else if unclosedIdx > 20 {
					msgEnd = unclosedIdx
				} else {
					// Last resort: split inside the block immediately.
					msgEnd = alignToRuneStart(content, maxLen-5)
					messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
					content = strings.TrimSpace(header + "\n" + content[msgEnd:])
					continue
				}
			}
		}

		if msgEnd <= 0 {
			msgEnd = alignToRuneStart(content, effectiveLimit)
		}

		messages = append(messages, content[:msgEnd])
		content = strings.TrimSpace(content[msgEnd:])
	}

	return messages
}

// alignToRuneStart moves idx left until it no longer points at a UTF-8
// continuation byte (0b10xxxxxx), so slicing at idx cannot split a rune.
// If the backtrack reaches the start of the string (malformed UTF-8), the
// original index is returned to guarantee forward progress.
func alignToRuneStart(s string, idx int) int {
	aligned := idx
	for aligned > 0 && aligned < len(s) && s[aligned]&0xC0 == 0x80 {
		aligned--
	}
	if aligned <= 0 {
		return idx
	}
	return aligned
}

// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a
// closing ```. Returns the position of the opening ``` or -1 if all code
// blocks are complete.
func findLastUnclosedCodeBlock(text string) int {
	inCodeBlock := false
	lastOpenIdx := -1

	for i := 0; i < len(text); i++ {
		if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
			// Fences toggle state; only remember opening fences.
			if !inCodeBlock {
				lastOpenIdx = i
			}
			inCodeBlock = !inCodeBlock
			i += 2
		}
	}

	if inCodeBlock {
		return lastOpenIdx
	}
	return -1
}

// findNextClosingCodeBlock finds the next ``` at or after startIdx.
// Returns the position just past the fence, or -1 if not found.
func findNextClosingCodeBlock(text string, startIdx int) int {
	for i := startIdx; i < len(text); i++ {
		if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
			return i + 3
		}
	}
	return -1
}

// findLastNewline finds the last newline within the final searchWindow bytes
// of s. Returns the newline's position or -1 if not found.
func findLastNewline(s string, searchWindow int) int {
	searchStart := len(s) - searchWindow
	if searchStart < 0 {
		searchStart = 0
	}
	for i := len(s) - 1; i >= searchStart; i-- {
		if s[i] == '\n' {
			return i
		}
	}
	return -1
}

// findLastSpace finds the last space or tab within the final searchWindow
// bytes of s. Returns its position or -1 if not found.
func findLastSpace(s string, searchWindow int) int {
	searchStart := len(s) - searchWindow
	if searchStart < 0 {
		searchStart = 0
	}
	for i := len(s) - 1; i >= searchStart; i-- {
		if s[i] == ' ' || s[i] == '\t' {
			return i
		}
	}
	return -1
}
func TestSplitMessage(t *testing.T) {
	longText := strings.Repeat("a", 2500)
	longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars

	tests := []struct {
		name         string
		content      string
		maxLen       int
		expectChunks int                                 // expected number of chunks
		checkContent func(t *testing.T, chunks []string) // optional custom validation
	}{
		{
			name:         "Empty message",
			content:      "",
			maxLen:       2000,
			expectChunks: 0,
		},
		{
			name:         "Short message fits in one chunk",
			content:      "Hello world",
			maxLen:       2000,
			expectChunks: 1,
		},
		{
			name:         "Simple split regular text",
			content:      longText,
			maxLen:       2000,
			expectChunks: 2,
			checkContent: func(t *testing.T, chunks []string) {
				if len(chunks[0]) > 2000 {
					t.Errorf("Chunk 0 too large: %d", len(chunks[0]))
				}
				// Plain ASCII with no natural boundaries: nothing is trimmed,
				// so the chunks must re-concatenate to the original exactly.
				if len(chunks[0])+len(chunks[1]) != len(longText) {
					t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText))
				}
			},
		},
		{
			name: "Split at newline",
			// 1750 chars, then a newline, then more chars.
			// Dynamic buffer: 2000 / 10 = 200.
			// Effective limit: 2000 - 200 = 1800.
			// Split should happen at the newline because it's at 1750 (< 1800).
			// Total length must be > 2000 to trigger a split: 1750 + 1 + 300 = 2051.
			content:      strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300),
			maxLen:       2000,
			expectChunks: 2,
			checkContent: func(t *testing.T, chunks []string) {
				if len(chunks[0]) != 1750 {
					t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0]))
				}
				if chunks[1] != strings.Repeat("b", 300) {
					t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1]))
				}
			},
		},
		{
			name:         "Long code block split",
			content:      "Prefix\n" + longCode,
			maxLen:       2000,
			expectChunks: 2,
			checkContent: func(t *testing.T, chunks []string) {
				// The first chunk must end with an injected closing fence
				if !strings.HasSuffix(chunks[0], "\n```") {
					t.Error("First chunk should end with injected closing fence")
				}
				// The second chunk must reopen the block with the original header
				if !strings.HasPrefix(chunks[1], "```go") {
					t.Error("Second chunk should start with injected code block header")
				}
			},
		},
		{
			name:         "Preserve Unicode characters",
			content:      strings.Repeat("\u4e16", 1000), // 3000 bytes
			maxLen:       2000,
			expectChunks: 2,
			checkContent: func(t *testing.T, chunks []string) {
				// NOTE(review): SplitMessage slices by byte offset, so a split could
				// in principle land inside a multi-byte rune; this case only asserts
				// the chunk still contains the repeated character. Consider also
				// asserting utf8.ValidString on every chunk — TODO confirm the
				// intended guarantee.
				if !strings.Contains(chunks[0], "\u4e16") {
					t.Error("Chunk should contain unicode characters")
				}
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := SplitMessage(tc.content, tc.maxLen)

			if tc.expectChunks == 0 {
				if len(got) != 0 {
					t.Errorf("Expected 0 chunks, got %d", len(got))
				}
				return
			}

			if len(got) != tc.expectChunks {
				t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got))
				// Log chunk sizes for debugging
				for i, c := range got {
					t.Logf("Chunk %d length: %d", i, len(c))
				}
				return // Stop further checks: checkContent assumes this specific split
			}

			if tc.checkContent != nil {
				tc.checkContent(t, got)
			}
		})
	}
}

// TestSplitMessage_CodeBlockIntegrity is a focused test for the core
// requirement: splitting inside a code block preserves syntax highlighting
// by closing and reopening the fence across the split.
func TestSplitMessage_CodeBlockIntegrity(t *testing.T) {
	// ~57 chars total, forced to split by a 40-char limit
	content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```"
	maxLen := 40

	chunks := SplitMessage(content, maxLen)

	if len(chunks) != 2 {
		t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks)
	}

	// First chunk must end with "\n```"
	if !strings.HasSuffix(chunks[0], "\n```") {
		t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0])
	}

	// Second chunk must start with the header "```go"
	if !strings.HasPrefix(chunks[1], "```go") {
		t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1])
	}

	// First chunk must respect maxLen even after the fence injection
	if len(chunks[0]) > 40 {
		t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0]))
	}
}