diff --git a/codex-rs/shell-escalation/Cargo.toml b/codex-rs/shell-escalation/Cargo.toml index 49b8f10bd31..89c903b57f6 100644 --- a/codex-rs/shell-escalation/Cargo.toml +++ b/codex-rs/shell-escalation/Cargo.toml @@ -14,13 +14,15 @@ libc = { workspace = true } serde_json = { workspace = true } path-absolutize = { workspace = true } serde = { workspace = true, features = ["derive"] } -socket2 = { workspace = true } +socket2 = { workspace = true, features = ["all"] } tokio = { workspace = true, features = [ "io-std", + "net", "macros", "process", "rt-multi-thread", "signal", + "time", ] } tokio-util = { workspace = true } tracing = { workspace = true } diff --git a/codex-rs/shell-escalation/src/lib.rs b/codex-rs/shell-escalation/src/lib.rs index 555d0f89e0d..68070898d74 100644 --- a/codex-rs/shell-escalation/src/lib.rs +++ b/codex-rs/shell-escalation/src/lib.rs @@ -1,21 +1,114 @@ #[cfg(unix)] -mod unix { - mod escalate_client; - mod escalate_protocol; - mod escalate_server; - mod escalation_policy; - mod socket; - mod stopwatch; - - pub use self::escalate_client::run; - pub use self::escalate_protocol::EscalateAction; - pub use self::escalate_server::EscalationPolicyFactory; - pub use self::escalate_server::ExecParams; - pub use self::escalate_server::ExecResult; - pub use self::escalate_server::run_escalate_server; - pub use self::escalation_policy::EscalationPolicy; - pub use self::stopwatch::Stopwatch; -} +pub mod unix; #[cfg(unix)] pub use unix::*; + +#[cfg(unix)] +pub use unix::escalate_client::run; +#[cfg(unix)] +pub use unix::escalate_protocol::EscalateAction; +#[cfg(unix)] +pub use unix::escalate_server::EscalationPolicyFactory; +#[cfg(unix)] +pub use unix::escalate_server::ExecParams; +#[cfg(unix)] +pub use unix::escalate_server::ExecResult; +#[cfg(unix)] +pub use unix::escalation_policy::EscalationPolicy; +#[cfg(unix)] +pub use unix::stopwatch::Stopwatch; + +#[cfg(unix)] +mod legacy_api { + use std::collections::HashMap; + use std::path::Path; + use std::path::PathBuf; + use std::sync::Arc; + use std::time::Duration; + + use codex_execpolicy::Policy; + use codex_protocol::config_types::WindowsSandboxLevel; + use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions; + use tokio::sync::RwLock; + use tokio_util::sync::CancellationToken; + + use crate::unix::escalate_server::EscalationPolicyFactory; + use crate::unix::escalate_server::ExecParams; + use crate::unix::escalate_server::ExecResult; + use crate::unix::escalate_server::SandboxState; + use crate::unix::escalate_server::ShellCommandExecutor; + + struct CoreShellCommandExecutor; + + #[async_trait::async_trait] + impl ShellCommandExecutor for CoreShellCommandExecutor { + async fn run( + &self, + command: Vec, + cwd: PathBuf, + env: HashMap, + cancel_rx: CancellationToken, + sandbox_state: &SandboxState, + ) -> anyhow::Result { + let result = codex_core::exec::process_exec_tool_call( + codex_core::exec::ExecParams { + command, + cwd, + expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx), + env, + network: None, + sandbox_permissions: ProtocolSandboxPermissions::UseDefault, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + justification: None, + arg0: None, + }, + &sandbox_state.sandbox_policy, + &sandbox_state.sandbox_cwd, + &sandbox_state.codex_linux_sandbox_exe, + sandbox_state.use_linux_sandbox_bwrap, + None, + ) + .await?; + + Ok(ExecResult { + exit_code: result.exit_code, + output: result.aggregated_output.text, + duration: result.duration, + timed_out: result.timed_out, + }) + } + } + + #[allow(clippy::too_many_arguments)] + pub async fn run_escalate_server( + exec_params: ExecParams, + sandbox_state: &codex_core::SandboxState, + shell_program: impl AsRef, + execve_wrapper: impl AsRef, + policy: Arc>, + escalation_policy_factory: impl EscalationPolicyFactory, + effective_timeout: Duration, + ) -> anyhow::Result { + let sandbox_state = SandboxState { + sandbox_policy: sandbox_state.sandbox_policy.clone(), + codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(), + sandbox_cwd: sandbox_state.sandbox_cwd.clone(), + use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap, + }; + crate::unix::escalate_server::run_escalate_server( + exec_params, + &sandbox_state, + shell_program, + execve_wrapper, + policy, + escalation_policy_factory, + effective_timeout, + &CoreShellCommandExecutor, + ) + .await + } +} + +#[cfg(unix)] +pub use legacy_api::run_escalate_server; diff --git a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs b/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs index 0be7af28fa1..ca6bd347daa 100644 --- a/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs +++ b/codex-rs/shell-escalation/src/unix/core_shell_escalation.rs @@ -40,7 +40,7 @@ impl ShellPolicyFactory { } } -struct ShellEscalationPolicy { +pub struct ShellEscalationPolicy { provider: Arc, stopwatch: Stopwatch, } diff --git a/codex-rs/shell-escalation/src/unix/escalate_server.rs b/codex-rs/shell-escalation/src/unix/escalate_server.rs index 0ee5fc27c40..d0f7c94be98 100644 --- a/codex-rs/shell-escalation/src/unix/escalate_server.rs +++ b/codex-rs/shell-escalation/src/unix/escalate_server.rs @@ -7,8 +7,8 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Context as _; -use codex_core::SandboxState; use codex_execpolicy::Policy; +use codex_protocol::protocol::SandboxPolicy; use path_absolutize::Absolutize as _; use tokio::process::Command; use tokio::sync::RwLock; @@ -27,6 +27,33 @@ use crate::unix::socket::AsyncDatagramSocket; use crate::unix::socket::AsyncSocket; use crate::unix::stopwatch::Stopwatch; +#[derive(Debug, Clone)] +/// Sandbox configuration forwarded to the embedding crate's process executor. +pub struct SandboxState { + pub sandbox_policy: SandboxPolicy, + pub codex_linux_sandbox_exe: Option, + pub sandbox_cwd: PathBuf, + pub use_linux_sandbox_bwrap: bool, +} + +#[async_trait::async_trait] +/// Adapter for running the shell command after the escalation server has been set up. +/// +/// This lets `shell-escalation` own the Unix escalation protocol while the caller +/// (for example `codex-core` or `exec-server`) keeps control over process spawning, +/// output capture, and sandbox integration. +pub trait ShellCommandExecutor: Send + Sync { + /// Runs the requested shell command and returns the captured result. + async fn run( + &self, + command: Vec, + cwd: PathBuf, + env: HashMap, + cancel_rx: CancellationToken, + sandbox_state: &SandboxState, + ) -> anyhow::Result; +} + #[derive(Debug, serde::Deserialize, serde::Serialize)] pub struct ExecParams { /// The bash string to execute. @@ -71,11 +98,12 @@ impl EscalateServer { params: ExecParams, cancel_rx: CancellationToken, sandbox_state: &SandboxState, + command_executor: &dyn ShellCommandExecutor, ) -> anyhow::Result { let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; let client_socket = escalate_client.into_inner(); + // Only the client endpoint should cross exec into the wrapper process. client_socket.set_cloexec(false)?; - let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone())); let mut env = std::env::vars().collect::>(); env.insert( @@ -91,47 +119,27 @@ impl EscalateServer { self.execve_wrapper.to_string_lossy().to_string(), ); - let ExecParams { - command, - workdir, - timeout_ms: _, - login, - } = params; - let result = codex_core::exec::process_exec_tool_call( - codex_core::exec::ExecParams { - command: vec![ - self.bash_path.to_string_lossy().to_string(), - if login == Some(false) { - "-c".to_string() - } else { - "-lc".to_string() - }, - command, - ], - cwd: PathBuf::from(&workdir), - expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx), - env, - network: None, - sandbox_permissions: codex_core::sandboxing::SandboxPermissions::UseDefault, - windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, - justification: None, - arg0: None, + let command = vec![ + self.bash_path.to_string_lossy().to_string(), + if params.login == Some(false) { + "-c".to_string() + } else { + "-lc".to_string() }, - &sandbox_state.sandbox_policy, - &sandbox_state.sandbox_cwd, - &sandbox_state.codex_linux_sandbox_exe, - sandbox_state.use_linux_sandbox_bwrap, - None, - ) - .await?; + params.command, + ]; + let result = command_executor + .run( + command, + PathBuf::from(¶ms.workdir), + env, + cancel_rx, + sandbox_state, + ) + .await?; escalate_task.abort(); - Ok(ExecResult { - exit_code: result.exit_code, - output: result.aggregated_output.text, - duration: result.duration, - timed_out: result.timed_out, - }) + Ok(result) } } @@ -142,6 +150,7 @@ pub trait EscalationPolicyFactory { fn create_policy(&self, policy: Arc>, stopwatch: Stopwatch) -> Self::Policy; } +#[allow(clippy::too_many_arguments)] pub async fn run_escalate_server( exec_params: ExecParams, sandbox_state: &SandboxState, @@ -150,6 +159,7 @@ pub async fn run_escalate_server( policy: Arc>, escalation_policy_factory: impl EscalationPolicyFactory, effective_timeout: Duration, + command_executor: &dyn ShellCommandExecutor, ) -> anyhow::Result { let stopwatch = Stopwatch::new(effective_timeout); let cancel_token = stopwatch.cancellation_token(); @@ -160,7 +170,7 @@ pub async fn run_escalate_server( ); escalate_server - .exec(exec_params, cancel_token, sandbox_state) + .exec(exec_params, cancel_token, sandbox_state, command_executor) .await } @@ -272,6 +282,7 @@ async fn handle_escalate_session_with_policy( .await?; } } + Ok(()) } @@ -279,7 +290,6 @@ async fn handle_escalate_session_with_policy( mod tests { use super::*; use pretty_assertions::assert_eq; - use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; diff --git a/codex-rs/shell-escalation/src/unix/mod.rs b/codex-rs/shell-escalation/src/unix/mod.rs index 0ae7941da26..555bd7246eb 100644 --- a/codex-rs/shell-escalation/src/unix/mod.rs +++ b/codex-rs/shell-escalation/src/unix/mod.rs @@ -1,7 +1,7 @@ +pub mod core_shell_escalation; pub mod escalate_client; pub mod escalate_protocol; pub mod escalate_server; pub mod escalation_policy; pub mod socket; -pub mod core_shell_escalation; pub mod stopwatch; diff --git a/codex-rs/shell-escalation/src/unix/socket.rs b/codex-rs/shell-escalation/src/unix/socket.rs index 35292367a6b..a66cc4eab51 100644 --- a/codex-rs/shell-escalation/src/unix/socket.rs +++ b/codex-rs/shell-escalation/src/unix/socket.rs @@ -96,8 +96,8 @@ async fn read_frame_header( while filled < LENGTH_PREFIX_SIZE { let mut guard = async_socket.readable().await?; // The first read should come with a control message containing any FDs. - let result = if !captured_control { - guard.try_io(|inner| { + let read = if !captured_control { + match guard.try_io(|inner| { let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])]; let (read, control_len) = { let mut msg = MsgHdrMut::new() @@ -109,16 +109,18 @@ async fn read_frame_header( control.truncate(control_len); captured_control = true; Ok(read) - }) + }) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, + } } else { - guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) - }; - let Ok(result) = result else { - // Would block, try again. - continue; + match guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..])) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, + } }; - - let read = result?; if read == 0 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -150,12 +152,11 @@ async fn read_frame_payload( let mut filled = 0; while filled < message_len { let mut guard = async_socket.readable().await?; - let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])); - let Ok(result) = result else { - // Would block, try again. - continue; + let read = match guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..])) { + Ok(Ok(read)) => read, + Ok(Err(err)) => return Err(err), + Err(_would_block) => continue, }; - let read = result?; if read == 0 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -261,7 +262,13 @@ impl AsyncSocket { } pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> { - let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + // `socket2::Socket::pair()` also applies "common flags" (including + // `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets. + // Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC` + // explicitly on both endpoints. + let (server, client) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?; + server.set_cloexec(true)?; + client.set_cloexec(true)?; Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?)) } @@ -314,11 +321,11 @@ async fn send_stream_frame( let mut include_fds = !fds.is_empty(); while written < frame.len() { let mut guard = socket.writable().await?; - let result = guard.try_io(|inner| { - send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds) - }); - let bytes_written = match result { - Ok(bytes_written) => bytes_written?, + let bytes_written = match guard + .try_io(|inner| send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)) + { + Ok(Ok(bytes_written)) => bytes_written, + Ok(Err(err)) => return Err(err), Err(_would_block) => continue, }; if bytes_written == 0 { @@ -370,7 +377,13 @@ impl AsyncDatagramSocket { } pub fn pair() -> std::io::Result<(Self, Self)> { - let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + // `socket2::Socket::pair()` also applies "common flags" (including + // `SO_NOSIGPIPE` on Apple platforms), which can fail for AF_UNIX sockets. + // Use `pair_raw()` to avoid those side effects, then restore `CLOEXEC` + // explicitly on both endpoints. + let (server, client) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?; + server.set_cloexec(true)?; + client.set_cloexec(true)?; Ok((Self::new(server)?, Self::new(client)?)) } @@ -472,7 +485,7 @@ mod tests { #[test] fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?; + let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::DGRAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind()); @@ -481,7 +494,7 @@ mod tests { #[test] fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> { - let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?; + let (socket, _peer) = Socket::pair_raw(Domain::UNIX, Type::STREAM, None)?; let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?; let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err(); assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());