diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go
index 713850f977..11a1d59da7 100644
--- a/pkg/tools/shell.go
+++ b/pkg/tools/shell.go
@@ -3,6 +3,7 @@ package tools
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"os/exec"
@@ -109,18 +110,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
 		cmd.Dir = cwd
 	}
 
+	prepareCommandForTermination(cmd)
+
 	var stdout, stderr bytes.Buffer
 	cmd.Stdout = &stdout
 	cmd.Stderr = &stderr
 
-	err := cmd.Run()
+	if err := cmd.Start(); err != nil {
+		return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
+	}
+
+	done := make(chan error, 1)
+	go func() {
+		done <- cmd.Wait()
+	}()
+
+	var err error
+	select {
+	case err = <-done:
+	case <-cmdCtx.Done():
+		_ = terminateProcessTree(cmd)
+		select {
+		case err = <-done:
+		case <-time.After(2 * time.Second):
+			if cmd.Process != nil {
+				_ = cmd.Process.Kill()
+			}
+			err = <-done
+		}
+	}
+
 	output := stdout.String()
 	if stderr.Len() > 0 {
 		output += "\nSTDERR:\n" + stderr.String()
 	}
 
 	if err != nil {
-		if cmdCtx.Err() == context.DeadlineExceeded {
+		if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
 			msg := fmt.Sprintf("Command timed out after %v", t.timeout)
 			return &ToolResult{
 				ForLLM: msg,
diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/shell_process_unix.go
new file mode 100644
index 0000000000..7b29a81bf7
--- /dev/null
+++ b/pkg/tools/shell_process_unix.go
@@ -0,0 +1,57 @@
+//go:build !windows
+
+package tools
+
+import (
+	"errors"
+	"os"
+	"os/exec"
+	"syscall"
+)
+
+// prepareCommandForTermination configures cmd so that, once started, the
+// child leads its own process group (Setpgid). terminateProcessTree signals
+// that group, so this must be called before cmd.Start. A nil cmd is ignored.
+func prepareCommandForTermination(cmd *exec.Cmd) {
+	if cmd == nil {
+		return
+	}
+	// Keep any attributes already set on cmd instead of replacing the struct.
+	if cmd.SysProcAttr == nil {
+		cmd.SysProcAttr = &syscall.SysProcAttr{}
+	}
+	cmd.SysProcAttr.Setpgid = true
+}
+
+// terminateProcessTree sends SIGKILL to the process group whose ID is the
+// PID of cmd's process, then kills that process directly as a fallback.
+// It is a no-op when cmd is nil or was never started. Targets that have
+// already exited are not treated as errors; any other failure is returned,
+// with the group-kill error taking precedence.
+func terminateProcessTree(cmd *exec.Cmd) error {
+	if cmd == nil || cmd.Process == nil {
+		return nil
+	}
+
+	pid := cmd.Process.Pid
+	if pid <= 0 {
+		return nil
+	}
+
+	// Kill the entire process group spawned by the shell command; a negative
+	// PID addresses the group. ESRCH means the group is already gone.
+	groupErr := syscall.Kill(-pid, syscall.SIGKILL)
+	if errors.Is(groupErr, syscall.ESRCH) {
+		groupErr = nil
+	}
+	// Fallback kill on the shell process itself.
+	procErr := cmd.Process.Kill()
+	if errors.Is(procErr, os.ErrProcessDone) {
+		procErr = nil
+	}
+
+	if groupErr != nil {
+		return groupErr
+	}
+	return procErr
+}
diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/shell_process_windows.go
new file mode 100644
index 0000000000..fe23b5c96f
--- /dev/null
+++ b/pkg/tools/shell_process_windows.go
@@ -0,0 +1,36 @@
+//go:build windows
+
+package tools
+
+import (
+	"os/exec"
+	"strconv"
+)
+
+// prepareCommandForTermination is the Windows counterpart of the Unix
+// process-group setup. No preparation is done here; terminateProcessTree
+// relies on taskkill /T to reach child processes instead.
+func prepareCommandForTermination(cmd *exec.Cmd) {
+	// no-op on Windows
+}
+
+// terminateProcessTree force-kills cmd's process and its children by running
+// "taskkill /T /F /PID <pid>", then kills the process directly as a fallback.
+// It is a no-op when cmd is nil or was never started. Both steps are
+// best-effort: their errors are discarded and nil is always returned.
+func terminateProcessTree(cmd *exec.Cmd) error {
+	if cmd == nil || cmd.Process == nil {
+		return nil
+	}
+
+	pid := cmd.Process.Pid
+	if pid <= 0 {
+		return nil
+	}
+
+	// /T includes child processes, /F forces termination.
+	_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
+	// Fallback kill on the process itself in case taskkill is unavailable or failed.
+	_ = cmd.Process.Kill()
+	return nil
+}
diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/shell_timeout_unix_test.go
new file mode 100644
index 0000000000..4c6388b9b0
--- /dev/null
+++ b/pkg/tools/shell_timeout_unix_test.go
@@ -0,0 +1,68 @@
+//go:build !windows
+
+package tools
+
+import (
+	"context"
+	"errors"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"syscall"
+	"testing"
+	"time"
+)
+
+// processExists reports whether a process with the given PID is present, by
+// sending it signal 0. EPERM counts as present: the process exists but may
+// not be signalled by this user.
+func processExists(pid int) bool {
+	if pid <= 0 {
+		return false
+	}
+	err := syscall.Kill(pid, 0)
+	return err == nil || errors.Is(err, syscall.EPERM)
+}
+
+// TestShellTool_TimeoutKillsChildProcess checks that when a command times
+// out, a background child it spawned is killed too, not just the shell.
+func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
+	tool := NewExecTool(t.TempDir(), false)
+	tool.SetTimeout(500 * time.Millisecond)
+
+	args := map[string]interface{}{
+		// Spawn a child process that would outlive the shell unless process-group kill is used.
+		"command": "sleep 60 & echo $! > child.pid; wait",
+	}
+
+	result := tool.Execute(context.Background(), args)
+	if !result.IsError {
+		t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
+	}
+	if !strings.Contains(result.ForLLM, "timed out") {
+		t.Fatalf("expected timeout message, got: %s", result.ForLLM)
+	}
+
+	childPIDPath := filepath.Join(tool.workingDir, "child.pid")
+	data, err := os.ReadFile(childPIDPath)
+	if err != nil {
+		t.Fatalf("failed to read child pid file: %v", err)
+	}
+
+	childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
+	if err != nil {
+		t.Fatalf("failed to parse child pid: %v", err)
+	}
+
+	// Allow a short window for the killed child to disappear from the process table.
+	deadline := time.Now().Add(2 * time.Second)
+	for time.Now().Before(deadline) {
+		if !processExists(childPID) {
+			return
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	t.Fatalf("child process %d is still running after timeout", childPID)
+}