diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 55d6929c025..74511b6c81b 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1859,6 +1859,7 @@ dependencies = [ "shlex", "tempfile", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index a6a0721b985..74b252d97b1 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -32,6 +32,8 @@ codex-core = { workspace = true } codex-execpolicy = { workspace = true } codex-protocol = { workspace = true } codex-shell-command = { workspace = true } + +[target.'cfg(unix)'.dependencies] codex-shell-escalation = { workspace = true } rmcp = { workspace = true, default-features = false, features = [ "auth", @@ -51,6 +53,7 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } shlex = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] } +tokio-util = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } diff --git a/codex-rs/exec-server/src/unix.rs b/codex-rs/exec-server/src/unix.rs index 0e0355443bc..2bd892fd649 100644 --- a/codex-rs/exec-server/src/unix.rs +++ b/codex-rs/exec-server/src/unix.rs @@ -67,7 +67,7 @@ use codex_execpolicy::Decision; use codex_execpolicy::Policy; use codex_execpolicy::RuleMatch; use codex_shell_command::is_dangerous_command::command_might_be_dangerous; -use codex_shell_escalation as shell_escalation; +use codex_shell_escalation::unix::escalate_client::run; use rmcp::ErrorData as McpError; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -160,7 +160,7 @@ pub async fn main_execve_wrapper() -> anyhow::Result<()> { .init(); let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse(); - let exit_code = shell_escalation::run(file, argv).await?; + let exit_code = run(file, argv).await?; std::process::exit(exit_code); } diff --git a/codex-rs/exec-server/src/unix/mcp.rs b/codex-rs/exec-server/src/unix/mcp.rs index 547d055c144..30dbc4d8194 100644 --- a/codex-rs/exec-server/src/unix/mcp.rs +++ b/codex-rs/exec-server/src/unix/mcp.rs @@ -6,11 +6,19 @@ use anyhow::Context as _; use anyhow::Result; use codex_core::MCP_SANDBOX_STATE_CAPABILITY; use codex_core::MCP_SANDBOX_STATE_METHOD; -use codex_core::SandboxState; +use codex_core::SandboxState as CoreSandboxState; +use codex_core::exec::process_exec_tool_call; use codex_execpolicy::Policy; +use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::models::SandboxPermissions as ProtocolSandboxPermissions; use codex_protocol::protocol::SandboxPolicy; -use codex_shell_escalation::EscalationPolicyFactory; -use codex_shell_escalation::run_escalate_server; +use codex_shell_escalation::unix::escalate_server::EscalationPolicyFactory; +use codex_shell_escalation::unix::escalate_server::ExecParams as ShellExecParams; +use codex_shell_escalation::unix::escalate_server::ExecResult as ShellExecResult; +use codex_shell_escalation::unix::escalate_server::SandboxState as ShellEscalationSandboxState; +use codex_shell_escalation::unix::escalate_server::ShellCommandExecutor; +use codex_shell_escalation::unix::escalate_server::run_escalate_server; +use codex_shell_escalation::unix::stopwatch::Stopwatch; use rmcp::ErrorData as McpError; use rmcp::RoleServer; use rmcp::ServerHandler; @@ -27,7 +35,9 @@ use rmcp::tool_handler; use rmcp::tool_router; use rmcp::transport::stdio; use serde_json::json; +use std::collections::HashMap; use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; use crate::unix::mcp_escalation_policy::McpEscalationPolicy; @@ -50,8 +60,8 @@ pub struct ExecResult { pub timed_out: bool, } -impl From for ExecResult { - fn from(result: codex_shell_escalation::ExecResult) -> Self { +impl From for ExecResult { + fn from(result: ShellExecResult) -> Self { Self { exit_code: result.exit_code, output: result.output, @@ -68,7 +78,7 @@ pub struct ExecTool { execve_wrapper: PathBuf, policy: Arc>, preserve_program_paths: bool, - sandbox_state: Arc>>, + sandbox_state: Arc>>, } #[derive(Debug, serde::Serialize, serde::Deserialize, rmcp::schemars::JsonSchema)] @@ -83,7 +93,7 @@ pub struct ExecParams { pub login: Option, } -impl From for codex_shell_escalation::ExecParams { +impl From for ShellExecParams { fn from(inner: ExecParams) -> Self { Self { command: inner.command, @@ -99,14 +109,51 @@ struct McpEscalationPolicyFactory { preserve_program_paths: bool, } +struct McpShellCommandExecutor; + +#[async_trait::async_trait] +impl ShellCommandExecutor for McpShellCommandExecutor { + async fn run( + &self, + command: Vec, + cwd: PathBuf, + env: HashMap, + cancel_rx: CancellationToken, + sandbox_state: &ShellEscalationSandboxState, + ) -> anyhow::Result { + let result = process_exec_tool_call( + codex_core::exec::ExecParams { + command, + cwd, + expiration: codex_core::exec::ExecExpiration::Cancellation(cancel_rx), + env, + network: None, + sandbox_permissions: ProtocolSandboxPermissions::UseDefault, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + justification: None, + arg0: None, + }, + &sandbox_state.sandbox_policy, + &sandbox_state.sandbox_cwd, + &sandbox_state.codex_linux_sandbox_exe, + sandbox_state.use_linux_sandbox_bwrap, + None, + ) + .await?; + + Ok(ShellExecResult { + exit_code: result.exit_code, + output: result.aggregated_output.text, + duration: result.duration, + timed_out: result.timed_out, + }) + } +} + impl EscalationPolicyFactory for McpEscalationPolicyFactory { type Policy = McpEscalationPolicy; - fn create_policy( - &self, - policy: Arc>, - stopwatch: codex_shell_escalation::Stopwatch, - ) -> Self::Policy { + fn create_policy(&self, policy: Arc>, stopwatch: Stopwatch) -> Self::Policy { McpEscalationPolicy::new( policy, self.context.clone(), @@ -151,15 +198,21 @@ impl ExecTool { .read() .await .clone() - .unwrap_or_else(|| SandboxState { + .unwrap_or_else(|| CoreSandboxState { sandbox_policy: SandboxPolicy::new_read_only_policy(), codex_linux_sandbox_exe: None, sandbox_cwd: PathBuf::from(¶ms.workdir), use_linux_sandbox_bwrap: false, }); + let shell_sandbox_state = ShellEscalationSandboxState { + sandbox_policy: sandbox_state.sandbox_policy.clone(), + codex_linux_sandbox_exe: sandbox_state.codex_linux_sandbox_exe.clone(), + sandbox_cwd: sandbox_state.sandbox_cwd.clone(), + use_linux_sandbox_bwrap: sandbox_state.use_linux_sandbox_bwrap, + }; let result = run_escalate_server( params.into(), - &sandbox_state, + &shell_sandbox_state, &self.bash_path, &self.execve_wrapper, self.policy.clone(), @@ -168,6 +221,7 @@ impl ExecTool { preserve_program_paths: self.preserve_program_paths, }, effective_timeout, + &McpShellCommandExecutor, ) .await .map_err(|e| McpError::internal_error(e.to_string(), None))?; @@ -236,7 +290,7 @@ impl ServerHandler for ExecTool { )); }; - let Ok(sandbox_state) = serde_json::from_value::(params.clone()) else { + let Ok(sandbox_state) = serde_json::from_value::(params.clone()) else { return Err(McpError::invalid_params( "failed to deserialize sandbox state".to_string(), Some(params), diff --git a/codex-rs/exec-server/src/unix/mcp_escalation_policy.rs b/codex-rs/exec-server/src/unix/mcp_escalation_policy.rs index 98638182619..73b2c34e428 100644 --- a/codex-rs/exec-server/src/unix/mcp_escalation_policy.rs +++ b/codex-rs/exec-server/src/unix/mcp_escalation_policy.rs @@ -2,9 +2,9 @@ use std::path::Path; use codex_core::sandboxing::SandboxPermissions; use codex_execpolicy::Policy; -use codex_shell_escalation::EscalateAction; -use codex_shell_escalation::EscalationPolicy; -use codex_shell_escalation::Stopwatch; +use codex_shell_escalation::unix::escalate_protocol::EscalateAction; +use codex_shell_escalation::unix::escalation_policy::EscalationPolicy; +use codex_shell_escalation::unix::stopwatch::Stopwatch; use rmcp::ErrorData as McpError; use rmcp::RoleServer; use rmcp::model::CreateElicitationRequestParams;