diff --git a/Makefile b/Makefile index 77af92d4a8..9581fa6332 100644 --- a/Makefile +++ b/Makefile @@ -47,6 +47,13 @@ define PATCH_MIPS_FLAGS fi endef +# Patch creack/pty for loong64 support (upstream doesn't have ztypes_loong64.go) +PTY_PATCH_LOONG64=pty_dir=$$(go env GOMODCACHE)/github.com/creack/pty@v1.1.9; \ + if [ -d "$$pty_dir" ] && [ ! -f "$$pty_dir/ztypes_loong64.go" ]; then \ + chmod +w "$$pty_dir" 2>/dev/null || true; \ + printf '//go:build linux && loong64\npackage pty\ntype (_C_int int32; _C_uint uint32)\n' > "$$pty_dir/ztypes_loong64.go"; \ + fi + # Golangci-lint GOLANGCI_LINT?=golangci-lint @@ -190,6 +197,7 @@ build-all: generate GOOS=linux GOARCH=amd64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR) GOOS=linux GOARCH=arm64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + @$(PTY_PATCH_LOONG64) GOOS=linux GOARCH=loong64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(GOFLAGS_NO_GOOLM) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) diff --git a/go.mod b/go.mod index e9ef37e984..e8f00d7a52 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.4.0 + github.com/creack/pty v1.1.9 github.com/ergochat/irc-go v0.6.0 github.com/ergochat/readline v0.1.3 github.com/gdamore/tcell/v2 v2.13.8 diff --git a/go.sum b/go.sum index 87117bc986..7cebd17446 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,7 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9 github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index e073cb929b..e296a18cbf 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -236,8 +236,9 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) { t.Fatal("exec tool not registered") } execResult := execTool.Execute(context.Background(), map[string]any{ - "command": "cat " + filepath.Base(mediaPath), - "working_dir": mediaDir, + "action": "run", + "command": "cat " + filepath.Base(mediaPath), + "cwd": mediaDir, }) if execResult.IsError { t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM) diff --git a/pkg/tools/session.go b/pkg/tools/session.go new file mode 100644 index 0000000000..141dd4b5ea --- /dev/null +++ b/pkg/tools/session.go @@ -0,0 +1,252 @@ +package tools + +import ( + "bytes" + "errors" + "io" + "os" + "sync" + "time" + + "github.com/google/uuid" +) + +const maxOutputBufferSize = 1 * 1024 * 1024 // 1MB + +const outputTruncateMarker = "\n... [output truncated, exceeded 1MB]\n" + +// PtyKeyMode represents arrow key encoding mode for PTY sessions. +// Programs send smkx/rmkx sequences to switch between CSI and SS3 modes. +type PtyKeyMode uint8 + +const ( + PtyKeyModeCSI PtyKeyMode = iota // triggered by rmkx (\x1b[?1l) + PtyKeyModeSS3 // triggered by smkx (\x1b[?1h) +) + +const PtyKeyModeNotFound PtyKeyMode = 255 + +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionDone = errors.New("session already completed") + ErrPTYNotSupported = errors.New("PTY is not supported on this platform") + ErrNoStdin = errors.New("no stdin available") +) + +type ProcessSession struct { + mu sync.Mutex + ID string + PID int + Command string + PTY bool + Background bool + StartTime int64 + ExitCode int + Status string + stdinWriter io.Writer + stdoutPipe io.Reader + outputBuffer *bytes.Buffer + outputTruncated bool + ptyMaster *os.File + + // ptyKeyMode tracks arrow key encoding mode (CSI vs SS3) + ptyKeyMode PtyKeyMode +} + +func (s *ProcessSession) IsDone() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.Status == "done" || s.Status == "exited" +} + +func (s *ProcessSession) GetPtyKeyMode() PtyKeyMode { + s.mu.Lock() + defer s.mu.Unlock() + return s.ptyKeyMode +} + +func (s *ProcessSession) SetPtyKeyMode(mode PtyKeyMode) { + s.mu.Lock() + defer s.mu.Unlock() + s.ptyKeyMode = mode +} + +func (s *ProcessSession) GetStatus() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.Status +} + +func (s *ProcessSession) SetStatus(status string) { + s.mu.Lock() + defer s.mu.Unlock() + s.Status = status +} + +func (s *ProcessSession) GetExitCode() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.ExitCode +} + +func (s *ProcessSession) SetExitCode(code int) { + s.mu.Lock() + defer s.mu.Unlock() + s.ExitCode = code +} + +func (s *ProcessSession) killProcess() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.Status != "running" { + return ErrSessionDone + } + + pid := s.PID + if pid <= 0 { + return ErrSessionNotFound + } + + if err := killProcessGroup(pid); err != nil { + return err + } + + s.Status = "done" + s.ExitCode = -1 + return nil +} + +func (s *ProcessSession) Kill() error { + return s.killProcess() +} + +func (s *ProcessSession) Write(data string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.Status != "running" { + return ErrSessionDone + } + + var writer io.Writer + if s.PTY && s.ptyMaster != nil { + writer = s.ptyMaster + } else if s.stdinWriter != nil { + writer = s.stdinWriter + } else { + return ErrNoStdin + } + + _, err := writer.Write([]byte(data)) + return err +} + +func (s *ProcessSession) Read() string { + s.mu.Lock() + defer s.mu.Unlock() + + if s.outputBuffer.Len() == 0 { + return "" + } + + data := s.outputBuffer.String() + s.outputBuffer.Reset() + return data +} + +func (s *ProcessSession) ToSessionInfo() SessionInfo { + s.mu.Lock() + defer s.mu.Unlock() + + return SessionInfo{ + ID: s.ID, + Command: s.Command, + Status: s.Status, + PID: s.PID, + StartedAt: s.StartTime, + } +} + +type SessionManager struct { + mu sync.RWMutex + sessions map[string]*ProcessSession +} + +func NewSessionManager() *SessionManager { + sm := &SessionManager{ + sessions: make(map[string]*ProcessSession), + } + + // Start cleaner goroutine - runs every 5 minutes, cleans up sessions done for >30 minutes + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + sm.cleanupOldSessions() + } + }() + + return sm +} + +// cleanupOldSessions removes sessions that are done and older than 30 minutes +func (sm *SessionManager) cleanupOldSessions() { + sm.mu.Lock() + defer sm.mu.Unlock() + + cutoff := time.Now().Add(-30 * time.Minute) + for id, session := range sm.sessions { + if session.IsDone() && session.StartTime < cutoff.Unix() { + delete(sm.sessions, id) + } + } +} + +func (sm *SessionManager) Add(session *ProcessSession) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.sessions[session.ID] = session +} + +func (sm *SessionManager) Get(sessionID string) (*ProcessSession, error) { + sm.mu.RLock() + defer sm.mu.RUnlock() + + session, ok := sm.sessions[sessionID] + if !ok { + return nil, ErrSessionNotFound + } + + return session, nil +} + +func (sm *SessionManager) Remove(sessionID string) { + sm.mu.Lock() + defer sm.mu.Unlock() + delete(sm.sessions, sessionID) +} + +func (sm *SessionManager) List() []SessionInfo { + sm.mu.RLock() + defer sm.mu.RUnlock() + + result := make([]SessionInfo, 0, len(sm.sessions)) + for _, session := range sm.sessions { + result = append(result, session.ToSessionInfo()) + } + + return result +} + +func generateSessionID() string { + return uuid.New().String()[:8] +} + +type SessionInfo struct { + ID string `json:"id"` + Command string `json:"command"` + Status string `json:"status"` + PID int `json:"pid"` + StartedAt int64 `json:"startedAt"` +} diff --git a/pkg/tools/session_process_unix.go b/pkg/tools/session_process_unix.go new file mode 100644 index 0000000000..2fe30166e5 --- /dev/null +++ b/pkg/tools/session_process_unix.go @@ -0,0 +1,14 @@ +//go:build !windows + +package tools + +import ( + "syscall" +) + +func killProcessGroup(pid int) error { + if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil { + _ = syscall.Kill(pid, syscall.SIGKILL) + } + return nil +} diff --git a/pkg/tools/session_process_windows.go b/pkg/tools/session_process_windows.go new file mode 100644 index 0000000000..7cf5589540 --- /dev/null +++ b/pkg/tools/session_process_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +package tools + +import ( + "os/exec" + "strconv" +) + +func killProcessGroup(pid int) error { + _ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run() + return nil +} diff --git a/pkg/tools/session_test.go b/pkg/tools/session_test.go new file mode 100644 index 0000000000..6cfe72a107 --- /dev/null +++ b/pkg/tools/session_test.go @@ -0,0 +1,99 @@ +package tools + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSessionManager_AddGet(t *testing.T) { + sm := NewSessionManager() + session := &ProcessSession{ + ID: "test-1", + Command: "echo hello", + Status: "running", + StartTime: 1000, + } + + sm.Add(session) + + got, err := sm.Get("test-1") + require.NoError(t, err) + require.Equal(t, "test-1", got.ID) +} + +func TestSessionManager_Remove(t *testing.T) { + sm := NewSessionManager() + session := &ProcessSession{ + ID: "test-1", + Command: "echo hello", + Status: "running", + StartTime: 1000, + } + sm.Add(session) + sm.Remove("test-1") + + _, err := sm.Get("test-1") + require.ErrorIs(t, err, ErrSessionNotFound) +} + +func TestSessionManager_List(t *testing.T) { + sm := NewSessionManager() + sm.Add(&ProcessSession{ + ID: "test-1", + Command: "echo hello", + Status: "running", + StartTime: 1000, + }) + sm.Add(&ProcessSession{ + ID: "test-2", + Command: "echo world", + Status: "running", + StartTime: 1001, + }) + sm.Add(&ProcessSession{ + ID: "test-3", + Command: "echo done", + Status: "done", + StartTime: 1002, + }) + + sessions := sm.List() + require.Len(t, sessions, 3) + + ids := make(map[string]bool) + for _, s := range sessions { + ids[s.ID] = true + } + require.True(t, ids["test-1"]) + require.True(t, ids["test-2"]) + require.True(t, ids["test-3"]) +} + +func TestProcessSession_IsDone(t *testing.T) { + session := &ProcessSession{Status: "running"} + require.False(t, session.IsDone()) + + session.Status = "done" + require.True(t, session.IsDone()) + + session.Status = "exited" + require.True(t, session.IsDone()) +} + +func TestProcessSession_ToSessionInfo(t *testing.T) { + session := &ProcessSession{ + ID: "test-1", + PID: 12345, + Command: "echo hello", + Status: "running", + StartTime: 1000, + } + + info := session.ToSessionInfo() + require.Equal(t, "test-1", info.ID) + require.Equal(t, "echo hello", info.Command) + require.Equal(t, "running", info.Status) + require.Equal(t, 12345, info.PID) + require.Equal(t, int64(1000), info.StartedAt) +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 78ad2b26d7..6ee1cb9936 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -3,20 +3,36 @@ package tools import ( "bytes" "context" + "encoding/json" "errors" "fmt" + "io" "os" "os/exec" "path/filepath" "regexp" "runtime" "strings" + "sync" "time" + "github.com/creack/pty" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" ) +var ( + globalSessionManager = NewSessionManager() + sessionManagerMu sync.RWMutex +) + +func getSessionManager() *SessionManager { + sessionManagerMu.RLock() + defer sessionManagerMu.RUnlock() + return globalSessionManager +} + type ExecTool struct { workingDir string timeout time.Duration @@ -26,6 +42,7 @@ type ExecTool struct { allowedPathPatterns []*regexp.Regexp restrictToWorkspace bool allowRemote bool + sessionManager *SessionManager } var ( @@ -145,7 +162,7 @@ func NewExecToolWithConfig( denyPatterns = append(denyPatterns, defaultDenyPatterns...) } - timeout := 60 * time.Second + var timeout time.Duration if config != nil && config.Tools.Exec.TimeoutSeconds > 0 { timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second } @@ -159,6 +176,7 @@ func NewExecToolWithConfig( allowedPathPatterns: allowedPathPatterns, restrictToWorkspace: restrict, allowRemote: allowRemote, + sessionManager: getSessionManager(), }, nil } @@ -167,27 +185,82 @@ func (t *ExecTool) Name() string { } func (t *ExecTool) Description() string { - return "Execute a shell command and return its output. Use with caution." + return `Execute shell commands. Use background=true for long-running commands (returns sessionId). Use pty=true for interactive commands (can combine with background=true). Use poll/read/write/send-keys/kill with sessionId to manage background sessions. Sessions auto-cleanup 30 minutes after process exits; use kill to terminate early. Output buffer limit: 1MB.` } func (t *ExecTool) Parameters() map[string]any { return map[string]any{ "type": "object", "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"run", "list", "poll", "read", "write", "kill", "send-keys"}, + "description": "Action: run (execute command), list (show sessions), poll (check status), read (get output), write (send input), kill (terminate), send-keys (send keys to PTY)", + }, "command": map[string]any{ "type": "string", - "description": "The shell command to execute", + "description": "Shell command to execute (required for run)", + }, + "sessionId": map[string]any{ + "type": "string", + "description": "Session ID (required for poll/read/write/kill/send-keys)", + }, + "keys": map[string]any{ + "type": "string", + "description": "Key names for send-keys: up, down, left, right, enter, tab, escape, backspace, ctrl-c, ctrl-d, home, end, pageup, pagedown, f1-f12", + }, + "data": map[string]any{ + "type": "string", + "description": "Data to write to stdin (required for write)", + }, + "background": map[string]any{ + "type": "string", + "description": "Run in background immediately", }, - "working_dir": map[string]any{ + "pty": map[string]any{ "type": "string", - "description": "Optional working directory for the command", + "description": "Run in a pseudo-terminal (PTY) when available", + }, + "cwd": map[string]any{ + "type": "string", + "description": "Working directory for the command", + }, + "timeout": map[string]any{ + "type": "integer", + "description": "Timeout in seconds (0 = no timeout)", }, }, - "required": []string{"command"}, + "required": []string{"action"}, } } func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + action, _ := args["action"].(string) + if action == "" { + return ErrorResult("action is required") + } + + switch action { + case "run": + return t.executeRun(ctx, args) + case "list": + return t.executeList() + case "poll": + return t.executePoll(args) + case "read": + return t.executeRead(args) + case "write": + return t.executeWrite(args) + case "kill": + return t.executeKill(args) + case "send-keys": + return t.executeSendKeys(args) + default: + return ErrorResult(fmt.Sprintf("unknown action: %s", action)) + } +} + +func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolResult { command, ok := args["command"].(string) if !ok { return ErrorResult("command is required") @@ -206,8 +279,26 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult } } + getBoolArg := func(key string) bool { + switch v := args[key].(type) { + case bool: + return v + case string: + return v == "true" + } + return false + } + isPty := getBoolArg("pty") + isBackground := getBoolArg("background") + + if isPty { + if runtime.GOOS == "windows" { + return ErrorResult("PTY is not supported on Windows. Use background=true without pty.") + } + } + cwd := t.workingDir - if wd, ok := args["working_dir"].(string); ok && wd != "" { + if wd, ok := args["cwd"].(string); ok && wd != "" { if t.restrictToWorkspace && t.workingDir != "" { resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns) if err != nil { @@ -253,6 +344,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult } } + if isBackground { + return t.runBackground(ctx, command, cwd, isPty) + } + + return t.runSync(ctx, command, cwd) +} + +func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult { // timeout == 0 means no timeout var cmdCtx context.Context var cancel context.CancelFunc @@ -361,6 +460,560 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult } } +func (t *ExecTool) runBackground(ctx context.Context, command, cwd string, ptyEnabled bool) *ToolResult { + sessionID := generateSessionID() + session := &ProcessSession{ + ID: sessionID, + Command: command, + PTY: ptyEnabled, + Background: true, + StartTime: time.Now().Unix(), + Status: "running", + ptyKeyMode: PtyKeyModeCSI, + } + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", command) + } else { + cmd = exec.Command("sh", "-c", command) + } + if cwd != "" { + cmd.Dir = cwd + } + + prepareCommandForTermination(cmd) + + var stdoutReader io.ReadCloser + var stderrReader io.ReadCloser + var stdinWriter io.WriteCloser + + if ptyEnabled { + ptmx, tty, err := pty.Open() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create PTY: %v", err)) + } + + cmd.Stdin = tty + cmd.Stdout = tty + cmd.Stderr = tty + + // For PTY, we need Setsid to create a new session. + // Note: Setsid and Setpgid conflict, so we must replace SysProcAttr entirely. + setSysProcAttrForPty(cmd) + + session.ptyMaster = ptmx + } else { + var err error + stdoutReader, err = cmd.StdoutPipe() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create stdout pipe: %v", err)) + } + stderrReader, err = cmd.StderrPipe() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create stderr pipe: %v", err)) + } + stdinWriter, err = cmd.StdinPipe() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to create stdin pipe: %v", err)) + } + session.stdoutPipe = io.MultiReader(stdoutReader, stderrReader) + session.stdinWriter = stdinWriter + } + + if err := cmd.Start(); err != nil { + if session.ptyMaster != nil { + session.ptyMaster.Close() + } + return ErrorResult(fmt.Sprintf("failed to start command: %v", err)) + } + + session.PID = cmd.Process.Pid + t.sessionManager.Add(session) + + session.outputBuffer = &bytes.Buffer{} + + // PTY mode: read from ptyMaster and wait for process + // Note: On Linux, closing ptyMaster doesn't interrupt blocking Read() calls, + // so we need cmd.Wait() in a separate goroutine to detect process exit. + if session.PTY && session.ptyMaster != nil { + go func() { + cmd.Wait() // Wait for process to exit + session.mu.Lock() + if cmd.ProcessState != nil { + session.ExitCode = cmd.ProcessState.ExitCode() + } + session.Status = "done" + session.mu.Unlock() + }() + + go func() { + buf := make([]byte, 4096) + for { + n, err := session.ptyMaster.Read(buf) + if n > 0 { + raw := string(buf[:n]) + if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound && mode != session.GetPtyKeyMode() { + session.SetPtyKeyMode(mode) + } + + session.mu.Lock() + if session.outputBuffer.Len() >= maxOutputBufferSize { + if !session.outputTruncated { + session.outputBuffer.WriteString(outputTruncateMarker) + session.outputTruncated = true + } + } else { + session.outputBuffer.Write(buf[:n]) + } + session.mu.Unlock() + } + if err != nil { + break + } + } + }() + } else { + // Non-PTY mode: single goroutine reads pipes. + // When Read() returns EOF (pipe closed), we break. + // When process exits, OS closes pipe write end → Read() returns EOF → we exit. + go func() { + buf := make([]byte, 4096) + + // Read stdout + for { + n, err := stdoutReader.Read(buf) + if n > 0 { + session.mu.Lock() + if session.outputBuffer.Len() >= maxOutputBufferSize { + if !session.outputTruncated { + session.outputBuffer.WriteString(outputTruncateMarker) + session.outputTruncated = true + } + } else { + session.outputBuffer.Write(buf[:n]) + } + session.mu.Unlock() + } + if err != nil { + break + } + } + + // Read stderr + for { + n, err := stderrReader.Read(buf) + if n > 0 { + session.mu.Lock() + if session.outputBuffer.Len() >= maxOutputBufferSize { + if !session.outputTruncated { + session.outputBuffer.WriteString(outputTruncateMarker) + session.outputTruncated = true + } + } else { + session.outputBuffer.Write(buf[:n]) + } + session.mu.Unlock() + } + if err != nil { + break + } + } + + // All pipes closed, get exit status + if stdinWriter != nil { + stdinWriter.Close() + } + cmd.Wait() + + session.mu.Lock() + if cmd.ProcessState != nil { + session.ExitCode = cmd.ProcessState.ExitCode() + } + session.Status = "done" + session.mu.Unlock() + }() + } + + resp := ExecResponse{ + SessionID: sessionID, + Status: "running", + } + data, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(data), + ForUser: fmt.Sprintf("Session %s started", sessionID), + IsError: false, + } +} + +func (t *ExecTool) executeList() *ToolResult { + sessions := t.sessionManager.List() + resp := ExecResponse{ + Sessions: sessions, + } + data, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(data), + ForUser: fmt.Sprintf("%d active sessions", len(sessions)), + IsError: false, + } +} + +func (t *ExecTool) executePoll(args map[string]any) *ToolResult { + sessionID, ok := args["sessionId"].(string) + if !ok { + return ErrorResult("sessionId is required") + } + + session, err := t.sessionManager.Get(sessionID) + if err != nil { + if errors.Is(err, ErrSessionNotFound) { + return ErrorResult(fmt.Sprintf("session not found: %s", sessionID)) + } + return ErrorResult(err.Error()) + } + + resp := ExecResponse{ + SessionID: sessionID, + Status: session.GetStatus(), + ExitCode: session.GetExitCode(), + } + data, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(data), + IsError: false, + } +} + +func (t *ExecTool) executeRead(args map[string]any) *ToolResult { + sessionID, ok := args["sessionId"].(string) + if !ok { + return ErrorResult("sessionId is required") + } + + session, err := t.sessionManager.Get(sessionID) + if err != nil { + if errors.Is(err, ErrSessionNotFound) { + return ErrorResult(fmt.Sprintf("session not found: %s", sessionID)) + } + return ErrorResult(err.Error()) + } + + output := session.Read() + + resp := ExecResponse{ + SessionID: sessionID, + Output: output, + Status: session.GetStatus(), + } + data, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(data), + IsError: false, + } +} + +func (t *ExecTool) executeWrite(args map[string]any) *ToolResult { + sessionID, ok := args["sessionId"].(string) + if !ok { + return ErrorResult("sessionId is required") + } + + data, ok := args["data"].(string) + if !ok { + return ErrorResult("data is required") + } + + session, err := t.sessionManager.Get(sessionID) + if err != nil { + if errors.Is(err, ErrSessionNotFound) { + return ErrorResult(fmt.Sprintf("session not found: %s", sessionID)) + } + return ErrorResult(err.Error()) + } + + if session.IsDone() { + return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode())) + } + + if err := session.Write(data); err != nil { + if errors.Is(err, ErrSessionDone) { + return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode())) + } + return ErrorResult(fmt.Sprintf("failed to write to session: %v", err)) + } + + resp := ExecResponse{ + SessionID: sessionID, + Status: session.GetStatus(), + } + respData, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(respData), + IsError: false, + } +} + +func (t *ExecTool) executeKill(args map[string]any) *ToolResult { + sessionID, ok := args["sessionId"].(string) + if !ok { + return ErrorResult("sessionId is required") + } + + session, err := t.sessionManager.Get(sessionID) + if err != nil { + if errors.Is(err, ErrSessionNotFound) { + return ErrorResult(fmt.Sprintf("session not found: %s", sessionID)) + } + return ErrorResult(err.Error()) + } + + if session.IsDone() { + return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode())) + } + + if err := session.Kill(); err != nil { + return ErrorResult(fmt.Sprintf("failed to kill session: %v", err)) + } + + t.sessionManager.Remove(sessionID) + + resp := ExecResponse{ + SessionID: sessionID, + Status: "done", + } + data, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(data), + ForUser: fmt.Sprintf("Session %s killed", sessionID), + IsError: false, + } +} + +// keyMap maps key names to their escape sequences. +var keyMap = map[string]string{ + "enter": "\r", + "return": "\r", + "tab": "\t", + "escape": "\x1b", + "esc": "\x1b", + "space": " ", + "backspace": "\x7f", + "bspace": "\x7f", + "up": "\x1b[A", + "down": "\x1b[B", + "right": "\x1b[C", + "left": "\x1b[D", + "home": "\x1b[1~", + "end": "\x1b[4~", + "pageup": "\x1b[5~", + "pagedown": "\x1b[6~", + "pgup": "\x1b[5~", + "pgdn": "\x1b[6~", + "insert": "\x1b[2~", + "ic": "\x1b[2~", + "delete": "\x1b[3~", + "del": "\x1b[3~", + "dc": "\x1b[3~", + "btab": "\x1b[Z", + "f1": "\x1bOP", + "f2": "\x1bOQ", + "f3": "\x1bOR", + "f4": "\x1bOS", + "f5": "\x1b[15~", + "f6": "\x1b[17~", + "f7": "\x1b[18~", + "f8": "\x1b[19~", + "f9": "\x1b[20~", + "f10": "\x1b[21~", + "f11": "\x1b[23~", + "f12": "\x1b[24~", +} + +// ss3KeysMap maps key names to SS3 escape sequences +var ss3KeysMap = map[string]string{ + "up": "\x1bOA", + "down": "\x1bOB", + "right": "\x1bOC", + "left": "\x1bOD", + "home": "\x1bOH", + "end": "\x1bOF", +} + +func detectPtyKeyMode(raw string) PtyKeyMode { + const SMKX = "\x1b[?1h" + const RMKX = "\x1b[?1l" + + lastSmkx := strings.LastIndex(raw, SMKX) + lastRmkx := strings.LastIndex(raw, RMKX) + + if lastSmkx == -1 && lastRmkx == -1 { + return PtyKeyModeNotFound + } + + if lastSmkx > lastRmkx { + return PtyKeyModeSS3 + } + return PtyKeyModeCSI +} + +// encodeKeyToken encodes a single key token into its escape sequence. +// Supports: +// - Named keys: "enter", "tab", "up", "ctrl-c", "alt-x", etc. +// - Ctrl modifier: "ctrl-c" or "c-c" (sends Ctrl+char) +// - Alt modifier: "alt-x" or "m-x" (sends ESC+char) +func encodeKeyToken(token string, ptyKeyMode PtyKeyMode) (string, error) { + token = strings.ToLower(strings.TrimSpace(token)) + if token == "" { + return "", nil + } + + // Handle ctrl-X format (c-x) + if strings.HasPrefix(token, "c-") { + char := token[2] + if char >= 'a' && char <= 'z' { + return string(rune(char) & 0x1f), nil // ctrl-a through ctrl-z + } + return "", fmt.Errorf("invalid ctrl key: %s", token) + } + + // Handle ctrl-X format (ctrl-x) + if strings.HasPrefix(token, "ctrl-") { + char := token[5] + if char >= 'a' && char <= 'z' { + return string(rune(char) & 0x1f), nil + } + return "", fmt.Errorf("invalid ctrl key: %s", token) + } + + // Handle alt-X format (m-x or alt-x) + if strings.HasPrefix(token, "m-") || strings.HasPrefix(token, "alt-") { + var char string + if strings.HasPrefix(token, "m-") { + char = token[2:] + } else { + char = token[4:] + } + if len(char) == 1 { + return "\x1b" + char, nil + } + return "", fmt.Errorf("invalid alt key: %s", token) + } + + // Handle shift modifier for special keys (shift-up, shift-down, etc.) + if strings.HasPrefix(token, "s-") || strings.HasPrefix(token, "shift-") { + var key string + if strings.HasPrefix(token, "s-") { + key = token[2:] + } else { + key = token[6:] + } + // Apply shift modifier: for single-char keys, return uppercase + if seq, ok := keyMap[key]; ok { + // For escape sequences, we can't easily add shift + // For single-char keys (letters), return uppercase + if len(seq) == 1 { + return strings.ToUpper(seq), nil + } + return seq, nil + } + return "", fmt.Errorf("unknown key with shift: %s", key) + } + + if ptyKeyMode == PtyKeyModeSS3 { + if seq, ok := ss3KeysMap[token]; ok { + return seq, nil + } + } + + if seq, ok := keyMap[token]; ok { + return seq, nil + } + + return "", fmt.Errorf("unknown key: %s (use write action for text input)", token) +} + +// encodeKeySequence encodes a slice of key tokens into a single string. +func encodeKeySequence(tokens []string, ptyKeyMode PtyKeyMode) (string, error) { + var result string + for _, token := range tokens { + seq, err := encodeKeyToken(token, ptyKeyMode) + if err != nil { + return "", err + } + result += seq + } + return result, nil +} + +func (t *ExecTool) executeSendKeys(args map[string]any) *ToolResult { + sessionID, ok := args["sessionId"].(string) + if !ok { + return ErrorResult("sessionId is required") + } + + keysStr, ok := args["keys"].(string) + if !ok { + return ErrorResult("keys must be a string") + } + + if keysStr == "" { + return ErrorResult("keys cannot be empty") + } + + // Parse comma-separated key names + keyNames := strings.Split(keysStr, ",") + var keys []string + for _, k := range keyNames { + k = strings.TrimSpace(k) + if k != "" { + keys = append(keys, k) + } + } + + if len(keys) == 0 { + return ErrorResult("keys cannot be empty") + } + + session, err := t.sessionManager.Get(sessionID) + if err != nil { + if errors.Is(err, ErrSessionNotFound) { + return ErrorResult(fmt.Sprintf("session not found: %s", sessionID)) + } + return ErrorResult(err.Error()) + } + + ptyKeyMode := session.GetPtyKeyMode() + + data, err := encodeKeySequence(keys, ptyKeyMode) + if err != nil { + return ErrorResult(fmt.Sprintf("invalid key: %v", err)) + } + + if session.IsDone() { + return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode())) + } + + if err := session.Write(data); err != nil { + if errors.Is(err, ErrSessionDone) { + return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode())) + } + return ErrorResult(fmt.Sprintf("failed to send keys: %v", err)) + } + + resp := ExecResponse{ + SessionID: sessionID, + Status: "running", + Output: fmt.Sprintf("Sent keys: %v", keys), + } + respData, _ := json.Marshal(resp) + return &ToolResult{ + ForLLM: string(respData), + IsError: false, + } +} + func (t *ExecTool) guardCommand(command, cwd string) string { cmd := strings.TrimSpace(command) lower := strings.ToLower(cmd) diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index f8f83ea747..a8de2f4c9c 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -2,12 +2,16 @@ package tools import ( "context" + "encoding/json" "os" "path/filepath" + "runtime" "strings" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/sipeed/picoclaw/pkg/config" ) @@ -20,6 +24,7 @@ func TestShellTool_Success(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "echo 'hello world'", } @@ -50,6 +55,7 @@ func TestShellTool_Failure(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "ls /nonexistent_directory_12345", } @@ -82,6 +88,7 @@ func TestShellTool_Timeout(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "sleep 10", } @@ -112,8 +119,9 @@ func TestShellTool_WorkingDir(t *testing.T) { ctx := context.Background() args := map[string]any{ - "command": "cat test.txt", - "working_dir": tmpDir, + "action": "run", + "command": "cat test.txt", + "cwd": tmpDir, } result := tool.Execute(ctx, args) @@ -136,6 +144,7 @@ func TestShellTool_DangerousCommand(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "rm -rf /", } @@ -159,6 +168,7 @@ func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "kill 12345", } @@ -198,6 +208,7 @@ func TestShellTool_StderrCapture(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "sh -c 'echo stdout; echo stderr >&2'", } @@ -222,6 +233,7 @@ func TestShellTool_OutputTruncation(t *testing.T) { ctx := context.Background() // Generate long output (>10000 chars) args := map[string]any{ + "action": "run", "command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000), } @@ -251,8 +263,9 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) { } result := tool.Execute(context.Background(), map[string]any{ - "command": "pwd", - "working_dir": outsideDir, + "action": "run", + "command": "pwd", + "cwd": outsideDir, }) if !result.IsError { @@ -289,8 +302,9 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) { } result := tool.Execute(context.Background(), map[string]any{ - "command": "cat secret.txt", - "working_dir": link, + "action": "run", + "command": "cat secret.txt", + "cwd": link, }) if !result.IsError { @@ -312,7 +326,7 @@ func TestShellTool_RemoteChannelBlockedByDefault(t *testing.T) { t.Fatalf("NewExecToolWithConfig() error: %v", err) } ctx := WithToolContext(context.Background(), "telegram", "chat-1") - result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + result := tool.Execute(ctx, map[string]any{"action": "run", "command": "echo hi"}) if !result.IsError { t.Fatal("expected remote-channel exec to be blocked") @@ -333,7 +347,7 @@ func TestShellTool_InternalChannelAllowed(t *testing.T) { t.Fatalf("NewExecToolWithConfig() error: %v", err) } ctx := WithToolContext(context.Background(), "cli", "direct") - result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + result := tool.Execute(ctx, map[string]any{"action": "run", "command": "echo hi"}) if result.IsError { t.Fatalf("expected internal channel exec to succeed, got: %s", result.ForLLM) @@ -373,7 +387,7 @@ func TestShellTool_AllowRemoteBypassesChannelCheck(t *testing.T) { t.Fatalf("NewExecToolWithConfig() error: %v", err) } ctx := WithToolContext(context.Background(), "telegram", "chat-1") - result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + result := tool.Execute(ctx, map[string]any{"action": "run", "command": "echo hi"}) if result.IsError { t.Fatalf("expected allowRemote=true to permit remote channel, got: %s", result.ForLLM) @@ -392,6 +406,7 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "cat ../../etc/passwd", } @@ -429,7 +444,7 @@ func TestShellTool_DevNullAllowed(t *testing.T) { } for _, cmd := range commands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if result.IsError && strings.Contains(result.ForLLM, "blocked") { t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM) } @@ -458,7 +473,7 @@ func TestShellTool_BlockDevices(t *testing.T) { } for _, cmd := range blocked { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if !result.IsError { t.Errorf("expected block device write to be blocked: %s", cmd) } @@ -482,7 +497,7 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) { } for _, cmd := range commands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") { t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM) } @@ -498,6 +513,7 @@ func TestShellTool_ExitCodeDetails(t *testing.T) { ctx := context.Background() args := map[string]any{ + "action": "run", "command": "sh -c 'exit 42'", } @@ -534,6 +550,7 @@ func TestShellTool_TimeoutWithPartialOutput(t *testing.T) { ctx := context.Background() // Use a command that outputs immediately then sleeps args := map[string]any{ + "action": "run", "command": "echo 'partial output before timeout' && sleep 30", } @@ -608,7 +625,9 @@ func TestShellTool_URLsNotBlocked(t *testing.T) { } for _, cmd := range commands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + result := tool.Execute(ctx, map[string]any{"action": "run", "command": cmd}) + cancel() if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") { t.Errorf("command with URL should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM) } @@ -633,7 +652,7 @@ func TestShellTool_FileURISandboxing(t *testing.T) { } for _, cmd := range blockedCommands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") { t.Errorf("file:// URI outside workspace should be blocked: %s", cmd) } @@ -651,7 +670,7 @@ func TestShellTool_FileURISandboxing(t *testing.T) { } for _, cmd := range allowedCommands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") { t.Errorf("file:// URI inside workspace should be allowed: %s\n error: %s", cmd, result.ForLLM) } @@ -677,9 +696,920 @@ func TestShellTool_URLBypassPrevented(t *testing.T) { } for _, cmd := range blockedCommands { - result := tool.Execute(context.Background(), map[string]any{"command": cmd}) + result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd}) if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") { t.Errorf("bypass attempt should be blocked: %q\n got: %s", cmd, result.ForLLM) } } } + +func TestShellTool_Background_ReturnsImmediately(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + ctx := context.Background() + args := map[string]any{ + "action": "run", + "command": "sleep 5", + "background": "true", + } + + start := time.Now() + result := tool.Execute(ctx, args) + elapsed := time.Since(start) + + require.False(t, result.IsError, "background run should not error: %s", result.ForLLM) + require.Less(t, elapsed, time.Second, "background run should return immediately") + require.Contains(t, result.ForLLM, "sessionId") +} + +func TestShellTool_List_Empty(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := context.Background() + args := map[string]any{"action": "list"} + + result := tool.Execute(ctx, args) + require.False(t, result.IsError) + require.Contains(t, result.ForUser, "0 active sessions") +} + +func TestShellTool_RunBackground_List(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 10", + "background": "true", + }) + require.False(t, runResult.IsError, "run should succeed: %s", runResult.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.SessionID) + + time.Sleep(100 * time.Millisecond) + + listResult := tool.Execute(ctx, map[string]any{"action": "list"}) + require.False(t, listResult.IsError) + + var listResp ExecResponse + err = json.Unmarshal([]byte(listResult.ForLLM), &listResp) + require.NoError(t, err) + require.Len(t, listResp.Sessions, 1) + require.Equal(t, resp.SessionID, listResp.Sessions[0].ID) + + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) +} + +func TestShellTool_Read_Output(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "echo hello", + "background": "true", + }) + require.False(t, runResult.IsError) + + var resp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &resp) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": resp.SessionID, + }) + + if !readResult.IsError { + var readResp ExecResponse + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + } +} + +func TestShellTool_Kill(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 100", + "background": "true", + }) + require.False(t, runResult.IsError) + + var resp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &resp) + require.NoError(t, err) + + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) + + time.Sleep(100 * time.Millisecond) + + listResult := tool.Execute(ctx, map[string]any{"action": "list"}) + var listResp ExecResponse + err = json.Unmarshal([]byte(listResult.ForLLM), &listResp) + require.NoError(t, err) + require.Len(t, listResp.Sessions, 0) +} + +func TestShellTool_PTY_AllowedCommands(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Test that PTY is allowed for non-interpreter commands + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "cat", + "pty": "true", + "background": "true", + }) + require.False(t, result.IsError, "PTY with cat should succeed: %s", result.ForLLM) + require.Contains(t, result.ForLLM, "sessionId") + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.SessionID) + + // Clean up + tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) +} + +func TestShellTool_PTY_WriteRead(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a PTY session with a command that waits for input + // Using 'cat' which will wait for stdin + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "cat", + "pty": "true", + "background": "true", + }) + require.False(t, result.IsError, "PTY run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Write some input to cat + writeResult := tool.Execute(ctx, map[string]any{ + "action": "write", + "sessionId": resp.SessionID, + "data": "hello\n", + }) + require.False(t, writeResult.IsError, "write should succeed: %s", writeResult.ForLLM) + + // Give cat time to process and output + time.Sleep(200 * time.Millisecond) + + // Read the output + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": resp.SessionID, + }) + + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + + var readResp ExecResponse + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + // PTY output should contain "hello" + require.Contains(t, readResp.Output, "hello") + + // Clean up + tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) +} + +func TestShellTool_PTY_Poll(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a PTY session with a long-running command + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 2", + "pty": "true", + "background": "true", + }) + require.False(t, result.IsError, "PTY run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Poll should show running + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.False(t, pollResult.IsError, "poll should succeed: %s", pollResult.ForLLM) + + var pollResp ExecResponse + err = json.Unmarshal([]byte(pollResult.ForLLM), &pollResp) + require.NoError(t, err) + require.Equal(t, "running", pollResp.Status) + + // Wait for sleep to complete + time.Sleep(2500 * time.Millisecond) + + // Poll should show done + pollResult = tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.False(t, pollResult.IsError) + + err = json.Unmarshal([]byte(pollResult.ForLLM), &pollResp) + require.NoError(t, err) + require.Equal(t, "done", pollResp.Status) +} + +func TestShellTool_PTY_Kill(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a PTY session with a long-running command + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 10", + "pty": "true", + "background": "true", + }) + require.False(t, result.IsError, "PTY run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Kill the session + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) + + // Verify kill response shows done status + var killResp ExecResponse + err = json.Unmarshal([]byte(killResult.ForLLM), &killResp) + require.NoError(t, err) + require.Equal(t, "done", killResp.Status) + + // Poll should return error since session is removed after kill + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + // Session is removed after kill, so poll returns error with "session not found" + require.True(t, pollResult.IsError, "poll should error after kill (session removed)") + require.Contains(t, pollResult.ForLLM, "session not found") +} + +func TestShellTool_Write_Read_NonPTY(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a background process that reads from stdin and outputs it + // Using 'cat' which echoes stdin to stdout + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "cat", + "pty": false, + "background": "true", + }) + require.False(t, result.IsError, "run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Write some input to cat + writeResult := tool.Execute(ctx, map[string]any{ + "action": "write", + "sessionId": resp.SessionID, + "data": "hello world\n", + }) + require.False(t, writeResult.IsError, "write should succeed: %s", writeResult.ForLLM) + + // Give cat time to process and output + time.Sleep(200 * time.Millisecond) + + // Read the output + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": resp.SessionID, + }) + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + + var readResp ExecResponse + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + require.Contains(t, readResp.Output, "hello world") + + // Clean up + tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) +} + +func TestShellTool_Read_NonPTY_Running(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a long-running process that produces output over time + // Using sh -c with sleep at the end so process doesn't exit immediately + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sh -c 'echo line1; sleep 0.5; echo line2; sleep 0.5; echo line3; sleep 10'", + "pty": false, + "background": "true", + }) + require.False(t, result.IsError, "run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Give time for first outputs to be produced + time.Sleep(300 * time.Millisecond) + + // Read output while process is running + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": resp.SessionID, + }) + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + + var readResp ExecResponse + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + // Should have at least line1 + require.Contains(t, readResp.Output, "line1") + + // Wait for line3 to be produced (line1=0s, line2=0.5s, line3=1s, then sleep 10) + time.Sleep(1200 * time.Millisecond) + + // Read again - should have line3 as well + readResult = tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": resp.SessionID, + }) + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + require.Contains(t, readResp.Output, "line3") + + // Clean up + tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) +} + +func TestShellTool_ProcessGroupKill(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Process group kill not supported on Windows") + } + + // Note: Testing process group kill with PTY is tricky because the command + // must be run through an interpreter (sh, bash) which is blocked for PTY. + // Instead, we test with non-PTY mode which also uses Setsid for background processes. + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a shell that spawns child processes (non-PTY mode) + // The sh -c command creates child sleep processes + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sh -c 'sleep 30 & sleep 30 & wait'", + "pty": false, + "background": "true", + }) + require.False(t, result.IsError, "run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Give time for child processes to spawn + time.Sleep(500 * time.Millisecond) + + // Kill the session - should kill the entire process group + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) + + // Verify kill response shows done status + var killResp ExecResponse + err = json.Unmarshal([]byte(killResult.ForLLM), &killResp) + require.NoError(t, err) + require.Equal(t, "done", killResp.Status) + + // Poll should return error since session is removed after kill + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.True(t, pollResult.IsError, "poll should error after kill (session removed)") + require.Contains(t, pollResult.ForLLM, "session not found") +} + +func TestShellTool_PTY_ProcessGroupKill(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY process group kill not supported on Windows") + } + + // This test binary creates 4 child sleep processes and waits for signals. + // It's not an interpreter, so it's allowed with PTY mode. + // The binary is created in /tmp/test_pgroup.c and compiled as part of test setup. + testBinary := "/tmp/test_pgroup" + if _, err := os.Stat(testBinary); os.IsNotExist(err) { + t.Skip("Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start the test binary with PTY mode + // It forks 4 child sleep processes and waits for signals + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": testBinary, + "pty": "true", + "background": "true", + }) + require.False(t, result.IsError, "run should succeed: %s", result.ForLLM) + + var resp ExecResponse + err = json.Unmarshal([]byte(result.ForLLM), &resp) + require.NoError(t, err) + + // Give time for child processes to spawn + time.Sleep(500 * time.Millisecond) + + // Kill the session - should kill the entire process group + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": resp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) + + // Verify kill response shows done status + var killResp ExecResponse + err = json.Unmarshal([]byte(killResult.ForLLM), &killResp) + require.NoError(t, err) + require.Equal(t, "done", killResp.Status) + + // Poll should return error since session is removed after kill + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.True(t, pollResult.IsError, "poll should error after kill (session removed)") + require.Contains(t, pollResult.ForLLM, "session not found") +} + +func TestShellTool_PTY_Background_Read(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a fast command with PTY + background mode + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "echo hello", + "pty": "true", + "background": "true", + }) + require.False(t, runResult.IsError, "run should succeed: %s", runResult.ForLLM) + + var runResp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &runResp) + require.NoError(t, err) + require.NotEmpty(t, runResp.SessionID) + require.Equal(t, "running", runResp.Status) + + // Wait for command to complete + time.Sleep(500 * time.Millisecond) + + // Read output - this is the key test: PTY + background mode should preserve output + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": runResp.SessionID, + }) + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + require.Contains(t, readResult.ForLLM, "hello", "output should contain 'hello'") +} + +func TestShellTool_PTY_Background_ReadNoBlock(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + // Start a long-running command with PTY + background mode + // This command produces no output, just sleeps + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 10", + "pty": "true", + "background": "true", + }) + require.False(t, runResult.IsError, "run should succeed: %s", runResult.ForLLM) + + var runResp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &runResp) + require.NoError(t, err) + require.NotEmpty(t, runResp.SessionID) + + // Read immediately - should NOT block even though process is running and has no output + // This tests that Read() returns quickly (within 1 second) instead of blocking for 10 seconds + start := time.Now() + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": runResp.SessionID, + }) + elapsed := time.Since(start) + + require.False(t, readResult.IsError, "read should succeed: %s", readResult.ForLLM) + require.Less(t, elapsed.Seconds(), 1.0, "read should not block, should return within 1 second") + + // Kill the session to clean up + killResult := tool.Execute(ctx, map[string]any{ + "action": "kill", + "sessionId": runResp.SessionID, + }) + require.False(t, killResult.IsError, "kill should succeed: %s", killResult.ForLLM) +} + +func TestShellTool_Poll_Status(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + sm := NewSessionManager() + tool.sessionManager = sm + + ctx := WithToolContext(context.Background(), "cli", "test") + + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "sleep 1", + "background": "true", + }) + require.False(t, runResult.IsError) + + var resp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &resp) + require.NoError(t, err) + + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.False(t, pollResult.IsError) + + var pollResp ExecResponse + err = json.Unmarshal([]byte(pollResult.ForLLM), &pollResp) + require.NoError(t, err) + require.Equal(t, "running", pollResp.Status) + + time.Sleep(1200 * time.Millisecond) + + pollResult = tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": resp.SessionID, + }) + require.False(t, pollResult.IsError) + + err = json.Unmarshal([]byte(pollResult.ForLLM), &pollResp) + require.NoError(t, err) + require.Equal(t, "done", pollResp.Status) +} + +func TestShellTool_Action_Run_Sync(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + ctx := context.Background() + + result := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "echo hello", + }) + + require.False(t, result.IsError) + require.Contains(t, result.ForLLM, "hello") +} + +// TestShellTool_Background_ReadAfterExit verifies that we can read +// buffered output even after the background process has exited. +func TestShellTool_Background_ReadAfterExit(t *testing.T) { + tool, err := NewExecTool("", false) + require.NoError(t, err) + + ctx := context.Background() + + // Start a background command that produces output and exits quickly + runResult := tool.Execute(ctx, map[string]any{ + "action": "run", + "command": "echo hello && sleep 1 && echo world", + "background": "true", + }) + require.False(t, runResult.IsError, "run should succeed: %s", runResult.ForUser) + + // Parse session ID from response + var resp ExecResponse + err = json.Unmarshal([]byte(runResult.ForLLM), &resp) + require.NoError(t, err) + require.NotEmpty(t, resp.SessionID) + sessionID := resp.SessionID + + // Wait for process to exit (sleep 1 + some buffer) + time.Sleep(1500 * time.Millisecond) + + // Poll to verify process is done + pollResult := tool.Execute(ctx, map[string]any{ + "action": "poll", + "sessionId": sessionID, + }) + require.False(t, pollResult.IsError, "poll should succeed: %s", pollResult.ForLLM) + var pollResp ExecResponse + err = json.Unmarshal([]byte(pollResult.ForLLM), &pollResp) + require.NoError(t, err) + require.Equal(t, "done", pollResp.Status, "process should be done") + + // Try to read output AFTER process has exited + readResult := tool.Execute(ctx, map[string]any{ + "action": "read", + "sessionId": sessionID, + }) + require.False(t, readResult.IsError, "read should succeed after exit: %s", readResult.ForLLM) + + var readResp ExecResponse + err = json.Unmarshal([]byte(readResult.ForLLM), &readResp) + require.NoError(t, err) + + // Output should contain both "hello" and "world" + require.Contains(t, readResp.Output, "hello", "should contain hello") + require.Contains(t, readResp.Output, "world", "should contain world after sleep") +} + +func TestSendKeys_CtrlC(t *testing.T) { + // Note: Ctrl-C as a signal requires sending SIGINT to the process group, + // which requires elevated privileges. Writing "\x03" to PTY passes the byte + // to the process but doesn't generate SIGINT for processes that don't read stdin. + // For interrupting processes, use the kill action instead. + t.Skip("Ctrl-C as signal not supported - use kill action for interruption") +} + +func TestEncodeKeyToken(t *testing.T) { + tests := []struct { + token string + expected string + hasError bool + }{ + // Named keys + {"enter", "\r", false}, + {"return", "\r", false}, + {"tab", "\t", false}, + {"escape", "\x1b", false}, + {"esc", "\x1b", false}, + {"backspace", "\x7f", false}, + {"up", "\x1b[A", false}, + {"down", "\x1b[B", false}, + {"left", "\x1b[D", false}, + {"right", "\x1b[C", false}, + {"home", "\x1b[1~", false}, + {"end", "\x1b[4~", false}, + {"pageup", "\x1b[5~", false}, + {"pagedown", "\x1b[6~", false}, + {"delete", "\x1b[3~", false}, + {"f1", "\x1bOP", false}, + {"f12", "\x1b[24~", false}, + + // Ctrl keys + {"ctrl-c", "\x03", false}, + {"ctrl-d", "\x04", false}, + {"ctrl-a", "\x01", false}, + {"ctrl-z", "\x1a", false}, + {"c-c", "\x03", false}, + {"c-d", "\x04", false}, + + // Alt keys + {"alt-x", "\x1bx", false}, + {"m-x", "\x1bx", false}, + + // Case insensitive tests + {"ENTER", "\r", false}, + {"TAB", "\t", false}, + {"CTRL-C", "\x03", false}, + {"Ctrl-D", "\x04", false}, + {"ALT-X", "\x1bx", false}, + {"M-X", "\x1bx", false}, + {"UP", "\x1b[A", false}, + {"DOWN", "\x1b[B", false}, + + // Unknown keys should return error (use write action for text input) + {"unknown-key", "", true}, + } + + for _, tt := range tests { + t.Run(tt.token, func(t *testing.T) { + result, err := encodeKeyToken(tt.token, PtyKeyModeCSI) + if tt.hasError { + require.Error(t, err, "expected error for %s", tt.token) + } else { + require.NoError(t, err, "unexpected error for %s", tt.token) + require.Equal(t, tt.expected, result, "wrong encoding for %s", tt.token) + } + }) + } +} + +// TestDetectPtyKeyMode tests smkx/rmkx detection in PTY output +func TestDetectPtyKeyMode(t *testing.T) { + tests := []struct { + name string + raw string + expected PtyKeyMode + }{ + {"no toggle", "hello world", PtyKeyModeNotFound}, + {"smkx only", "\x1b[?1h\x1b=", PtyKeyModeSS3}, + {"rmkx only", "\x1b[?1l\x1b>", PtyKeyModeCSI}, + {"both smkx first", "\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI}, + {"both rmkx first", "\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3}, + {"multiple toggles smkx last", "\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3}, + {"multiple toggles rmkx last", "\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI}, + {"partial smkx", "\x1b[?1h", PtyKeyModeSS3}, + {"partial rmkx", "\x1b[?1l", PtyKeyModeCSI}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := detectPtyKeyMode(tt.raw) + require.Equal(t, tt.expected, result, "wrong mode for %s", tt.name) + }) + } +} + +func TestEncodeKeyTokenWithPtyKeyMode(t *testing.T) { + tests := []struct { + name string + token string + mode PtyKeyMode + expected string + hasError bool + }{ + // CSI mode + {"up csi", "up", PtyKeyModeCSI, "\x1b[A", false}, + {"down csi", "down", PtyKeyModeCSI, "\x1b[B", false}, + {"left csi", "left", PtyKeyModeCSI, "\x1b[D", false}, + {"right csi", "right", PtyKeyModeCSI, "\x1b[C", false}, + + // SS3 mode + {"up ss3", "up", PtyKeyModeSS3, "\x1bOA", false}, + {"down ss3", "down", PtyKeyModeSS3, "\x1bOB", false}, + {"left ss3", "left", PtyKeyModeSS3, "\x1bOD", false}, + {"right ss3", "right", PtyKeyModeSS3, "\x1bOC", false}, + {"home ss3", "home", PtyKeyModeSS3, "\x1bOH", false}, + {"end ss3", "end", PtyKeyModeSS3, "\x1bOF", false}, + + // Other keys unaffected by mode + {"enter ss3", "enter", PtyKeyModeSS3, "\r", false}, + {"tab ss3", "tab", PtyKeyModeSS3, "\t", false}, + {"ctrl-c ss3", "ctrl-c", PtyKeyModeSS3, "\x03", false}, + + // NotFound behaves like CSI + {"up notfound", "up", PtyKeyModeNotFound, "\x1b[A", false}, + {"down notfound", "down", PtyKeyModeNotFound, "\x1b[B", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := encodeKeyToken(tt.token, tt.mode) + if tt.hasError { + require.Error(t, err, "expected error for %s", tt.name) + } else { + require.NoError(t, err, "unexpected error for %s", tt.name) + require.Equal(t, tt.expected, result, "wrong encoding for %s", tt.name) + } + }) + } +} diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/shell_timeout_unix_test.go index 357e1276ef..dfd28454c7 100644 --- a/pkg/tools/shell_timeout_unix_test.go +++ b/pkg/tools/shell_timeout_unix_test.go @@ -30,6 +30,7 @@ func TestShellTool_TimeoutKillsChildProcess(t *testing.T) { tool.SetTimeout(500 * time.Millisecond) args := map[string]any{ + "action": "run", // Spawn a child process that would outlive the shell unless process-group kill is used. "command": "sleep 60 & echo $! > child.pid; wait", } diff --git a/pkg/tools/sysproc_unix.go b/pkg/tools/sysproc_unix.go new file mode 100644 index 0000000000..0fb03d43a8 --- /dev/null +++ b/pkg/tools/sysproc_unix.go @@ -0,0 +1,12 @@ +//go:build !windows + +package tools + +import ( + "os/exec" + "syscall" +) + +func setSysProcAttrForPty(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} +} diff --git a/pkg/tools/sysproc_windows.go b/pkg/tools/sysproc_windows.go new file mode 100644 index 0000000000..150f166fb2 --- /dev/null +++ b/pkg/tools/sysproc_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package tools + +import "os/exec" + +func setSysProcAttrForPty(cmd *exec.Cmd) { + // Windows doesn't support Setsid, and PTY is not available on Windows anyway. + // This function is a no-op for Windows builds. +} diff --git a/pkg/tools/types.go b/pkg/tools/types.go index a6015cde33..4d1a18d5a1 100644 --- a/pkg/tools/types.go +++ b/pkg/tools/types.go @@ -56,3 +56,24 @@ type ToolFunctionDefinition struct { Description string `json:"description"` Parameters map[string]any `json:"parameters"` } + +type ExecRequest struct { + Action string `json:"action"` + Command string `json:"command,omitempty"` + PTY bool `json:"pty,omitempty"` + Background bool `json:"background,omitempty"` + Timeout int `json:"timeout,omitempty"` + Env map[string]string `json:"env,omitempty"` + Cwd string `json:"cwd,omitempty"` + SessionID string `json:"sessionId,omitempty"` + Data string `json:"data,omitempty"` +} + +type ExecResponse struct { + SessionID string `json:"sessionId,omitempty"` + Status string `json:"status,omitempty"` + ExitCode int `json:"exitCode,omitempty"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + Sessions []SessionInfo `json:"sessions,omitempty"` +}