Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions pkg/tools/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tools
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -109,18 +110,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
cmd.Dir = cwd
}

prepareCommandForTermination(cmd)

var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

err := cmd.Run()
if err := cmd.Start(); err != nil {
return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
}

done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()

var err error
select {
case err = <-done:
case <-cmdCtx.Done():
_ = terminateProcessTree(cmd)
select {
case err = <-done:
case <-time.After(2 * time.Second):
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
err = <-done
}
}

output := stdout.String()
if stderr.Len() > 0 {
output += "\nSTDERR:\n" + stderr.String()
}

if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded {
if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
Expand Down
32 changes: 32 additions & 0 deletions pkg/tools/shell_process_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//go:build !windows

package tools

import (
"os/exec"
"syscall"
)

func prepareCommandForTermination(cmd *exec.Cmd) {
if cmd == nil {
return
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

func terminateProcessTree(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
return nil
}

pid := cmd.Process.Pid
if pid <= 0 {
return nil
}

// Kill the entire process group spawned by the shell command.
_ = syscall.Kill(-pid, syscall.SIGKILL)
// Fallback kill on the shell process itself.
_ = cmd.Process.Kill()
return nil
}
27 changes: 27 additions & 0 deletions pkg/tools/shell_process_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//go:build windows

package tools

import (
"os/exec"
"strconv"
)

func prepareCommandForTermination(cmd *exec.Cmd) {
// no-op on Windows
}

func terminateProcessTree(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
return nil
}

pid := cmd.Process.Pid
if pid <= 0 {
return nil
}

_ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
_ = cmd.Process.Kill()
return nil
}
61 changes: 61 additions & 0 deletions pkg/tools/shell_timeout_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//go:build !windows

package tools

import (
"context"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
)

func processExists(pid int) bool {
if pid <= 0 {
return false
}
err := syscall.Kill(pid, 0)
return err == nil || err == syscall.EPERM
}

func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
tool := NewExecTool(t.TempDir(), false)
tool.SetTimeout(500 * time.Millisecond)

args := map[string]interface{}{
// Spawn a child process that would outlive the shell unless process-group kill is used.
"command": "sleep 60 & echo $! > child.pid; wait",
}

result := tool.Execute(context.Background(), args)
if !result.IsError {
t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "timed out") {
t.Fatalf("expected timeout message, got: %s", result.ForLLM)
}

childPIDPath := filepath.Join(tool.workingDir, "child.pid")
data, err := os.ReadFile(childPIDPath)
if err != nil {
t.Fatalf("failed to read child pid file: %v", err)
}

childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil {
t.Fatalf("failed to parse child pid: %v", err)
}

deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if !processExists(childPID) {
return
}
time.Sleep(50 * time.Millisecond)
}

t.Fatalf("child process %d is still running after timeout", childPID)
}