diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 2d22351d1f8..e9f98da5dc7 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -489,6 +489,168 @@ impl AsyncManagedClient { } } +#[derive(Clone)] +struct RestartableManagedClient { + server_name: String, + config: McpServerConfig, + store_mode: OAuthCredentialsStoreMode, + cancel_token: CancellationToken, + tx_event: Sender, + elicitation_requests: ElicitationRequestManager, + codex_apps_tools_cache_context: Option, + sandbox_state: Arc>, + state: Arc>, +} + +struct RestartableManagedClientState { + generation: u64, + client: AsyncManagedClient, +} + +impl RestartableManagedClient { + #[allow(clippy::too_many_arguments)] + fn new( + server_name: String, + config: McpServerConfig, + store_mode: OAuthCredentialsStoreMode, + cancel_token: CancellationToken, + tx_event: Sender, + elicitation_requests: ElicitationRequestManager, + codex_apps_tools_cache_context: Option, + sandbox_state: Arc>, + ) -> Self { + let client = Self::build_async_managed_client( + &server_name, + &config, + store_mode, + &cancel_token, + tx_event.clone(), + elicitation_requests.clone(), + codex_apps_tools_cache_context.clone(), + ); + + Self { + server_name, + config, + store_mode, + cancel_token, + tx_event, + elicitation_requests, + codex_apps_tools_cache_context, + sandbox_state, + state: Arc::new(Mutex::new(RestartableManagedClientState { + generation: 0, + client, + })), + } + } + + fn build_async_managed_client( + server_name: &str, + config: &McpServerConfig, + store_mode: OAuthCredentialsStoreMode, + cancel_token: &CancellationToken, + tx_event: Sender, + elicitation_requests: ElicitationRequestManager, + codex_apps_tools_cache_context: Option, + ) -> AsyncManagedClient { + AsyncManagedClient::new( + server_name.to_string(), + config.clone(), + store_mode, + cancel_token.child_token(), + tx_event, + elicitation_requests, + codex_apps_tools_cache_context, + ) + } + + fn allows_tool(&self, tool: &str) -> bool { + ToolFilter::from_config(&self.config).allows(tool) + } + + async fn client(&self) -> Result { + let client = { self.state.lock().await.client.clone() }; + client.client().await + } + + async fn listed_tools(&self) -> Option> { + let client = { self.state.lock().await.client.clone() }; + client.listed_tools().await + } + + async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { + let client = { self.state.lock().await.client.clone() }; + client.notify_sandbox_state_change(sandbox_state).await + } + + async fn call_tool( + &self, + tool: &str, + arguments: Option, + ) -> Result { + let (managed, generation) = self.client_with_generation().await?; + let result = managed + .client + .call_tool(tool.to_string(), arguments.clone(), managed.tool_timeout) + .await; + + match result { + Ok(result) => Ok(result), + Err(err) if is_transport_closed_error(&err) => { + self.restart_if_stale(generation).await; + + let managed = self.client().await.map_err(|err| anyhow!(err))?; + let sandbox_state = { self.sandbox_state.lock().await.clone() }; + if let Err(error) = managed.notify_sandbox_state_change(&sandbox_state).await { + warn!( + "Failed to notify sandbox state to MCP server {} after restart: {error:#}", + self.server_name + ); + } + + managed + .client + .call_tool(tool.to_string(), arguments, managed.tool_timeout) + .await + } + Err(err) => Err(err), + } + } + + async fn client_with_generation(&self) -> Result<(ManagedClient, u64)> { + let (client, generation) = { + let guard = self.state.lock().await; + (guard.client.clone(), guard.generation) + }; + let managed = client.client().await.map_err(|err| anyhow!(err))?; + Ok((managed, generation)) + } + + async fn restart_if_stale(&self, generation: u64) { + let mut guard = self.state.lock().await; + if guard.generation != generation { + return; + } + + warn!( + "MCP transport closed for server {}; restarting", + self.server_name + ); + + guard.generation = guard.generation.saturating_add(1); + guard.client = Self::build_async_managed_client( + &self.server_name, + &self.config, + self.store_mode, + &self.cancel_token, + self.tx_event.clone(), + self.elicitation_requests.clone(), + self.codex_apps_tools_cache_context.clone(), + ); + } +} + pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; /// Custom MCP request to push sandbox state updates. @@ -507,9 +669,10 @@ pub struct SandboxState { /// A thin wrapper around a set of running [`RmcpClient`] instances. pub(crate) struct McpConnectionManager { - clients: HashMap, + clients: HashMap, server_origins: HashMap, elicitation_requests: ElicitationRequestManager, + sandbox_state: Option>>, } impl McpConnectionManager { @@ -518,6 +681,7 @@ impl McpConnectionManager { clients: HashMap::new(), server_origins: HashMap::new(), elicitation_requests: ElicitationRequestManager::new(approval_policy.value()), + sandbox_state: None, } } @@ -559,6 +723,7 @@ impl McpConnectionManager { let mut join_set = JoinSet::new(); let elicitation_requests = ElicitationRequestManager::new(approval_policy.value()); let mcp_servers = mcp_servers.clone(); + let sandbox_state = Arc::new(Mutex::new(initial_sandbox_state.clone())); for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { if let Some(origin) = transport_origin(&cfg.transport) { server_origins.insert(server_name.clone(), origin); @@ -580,7 +745,7 @@ impl McpConnectionManager { } else { None }; - let async_managed_client = AsyncManagedClient::new( + let async_managed_client = RestartableManagedClient::new( server_name.clone(), cfg, store_mode, @@ -588,11 +753,12 @@ impl McpConnectionManager { tx_event.clone(), elicitation_requests.clone(), codex_apps_tools_cache_context, + sandbox_state.clone(), ); clients.insert(server_name.clone(), async_managed_client.clone()); let tx_event = tx_event.clone(); let auth_entry = auth_entries.get(&server_name).cloned(); - let sandbox_state = initial_sandbox_state.clone(); + let sandbox_state = sandbox_state.clone(); join_set.spawn(async move { let outcome = async_managed_client.client().await; if cancel_token.is_cancelled() { @@ -601,8 +767,9 @@ impl McpConnectionManager { let status = match &outcome { Ok(_) => { // Send sandbox state notification immediately after Ready + let sandbox_state_snapshot = { sandbox_state.lock().await.clone() }; if let Err(e) = async_managed_client - .notify_sandbox_state_change(&sandbox_state) + .notify_sandbox_state_change(&sandbox_state_snapshot) .await { warn!( @@ -637,6 +804,7 @@ impl McpConnectionManager { clients, server_origins, elicitation_requests: elicitation_requests.clone(), + sandbox_state: Some(sandbox_state), }; tokio::spawn(async move { let outcomes = join_set.join_all().await; @@ -919,16 +1087,18 @@ impl McpConnectionManager { tool: &str, arguments: Option, ) -> Result { - let client = self.client_by_name(server).await?; - if !client.tool_filter.allows(tool) { + let client = self + .clients + .get(server) + .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?; + if !client.allows_tool(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" )); } let result: rmcp::model::CallToolResult = client - .client - .call_tool(tool.to_string(), arguments, client.tool_timeout) + .call_tool(tool, arguments) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; @@ -1006,6 +1176,10 @@ impl McpConnectionManager { } pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { + if let Some(state) = &self.sandbox_state { + *state.lock().await = sandbox_state.clone(); + } + let mut join_set = JoinSet::new(); for async_managed_client in self.clients.values() { @@ -1034,6 +1208,14 @@ impl McpConnectionManager { } } +fn is_transport_closed_error(error: &anyhow::Error) -> bool { + error.chain().any(|err| { + err.to_string() + .to_ascii_lowercase() + .contains("transport closed") + }) +} + async fn emit_update( tx_event: &Sender, update: McpStartupUpdateEvent, @@ -1638,6 +1820,57 @@ mod tests { } } + fn create_test_mcp_server_config() -> McpServerConfig { + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "test-server".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + } + } + + fn create_test_sandbox_state() -> SandboxState { + SandboxState { + sandbox_policy: SandboxPolicy::new_read_only_policy(), + codex_linux_sandbox_exe: None, + sandbox_cwd: PathBuf::new(), + use_linux_sandbox_bwrap: false, + } + } + + fn create_restartable_client_for_tests( + server_name: &str, + async_client: AsyncManagedClient, + ) -> RestartableManagedClient { + let (tx_event, _rx_event) = async_channel::unbounded::(); + RestartableManagedClient { + server_name: server_name.to_string(), + config: create_test_mcp_server_config(), + store_mode: OAuthCredentialsStoreMode::Auto, + cancel_token: CancellationToken::new(), + tx_event, + elicitation_requests: ElicitationRequestManager::new(AskForApproval::OnFailure), + codex_apps_tools_cache_context: None, + sandbox_state: Arc::new(Mutex::new(create_test_sandbox_state())), + state: Arc::new(Mutex::new(RestartableManagedClientState { + generation: 0, + client: async_client, + })), + } + } + #[test] fn elicitation_reject_policy_defaults_to_prompting() { assert!(!elicitation_is_rejected_by_policy( @@ -1987,11 +2220,14 @@ mod tests { let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); manager.clients.insert( CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: Some(startup_tools), - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - }, + create_restartable_client_for_tests( + CODEX_APPS_MCP_SERVER_NAME, + AsyncManagedClient { + client: pending_client, + startup_snapshot: Some(startup_tools), + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + }, + ), ); let tools = manager.list_all_tools().await; @@ -2012,11 +2248,14 @@ mod tests { let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); manager.clients.insert( CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: None, - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - }, + create_restartable_client_for_tests( + CODEX_APPS_MCP_SERVER_NAME, + AsyncManagedClient { + client: pending_client, + startup_snapshot: None, + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + }, + ), ); let timeout_result = @@ -2034,11 +2273,14 @@ mod tests { let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); manager.clients.insert( CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: Some(Vec::new()), - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - }, + create_restartable_client_for_tests( + CODEX_APPS_MCP_SERVER_NAME, + AsyncManagedClient { + client: pending_client, + startup_snapshot: Some(Vec::new()), + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + }, + ), ); let timeout_result = @@ -2065,11 +2307,14 @@ mod tests { let startup_complete = Arc::new(std::sync::atomic::AtomicBool::new(true)); manager.clients.insert( CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: failed_client, - startup_snapshot: Some(startup_tools), - startup_complete, - }, + create_restartable_client_for_tests( + CODEX_APPS_MCP_SERVER_NAME, + AsyncManagedClient { + client: failed_client, + startup_snapshot: Some(startup_tools), + startup_complete, + }, + ), ); let tools = manager.list_all_tools().await; diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index dfbac85a096..e8816d5e1ae 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -531,6 +531,165 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_transport_restart)] +async fn mcp_tool_call_recovers_from_transport_closed() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let server_name = "rmcp_flaky"; + let tool_name = format!("mcp__{server_name}__echo"); + + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call("call-1", &tool_name, "{\"message\":\"first\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "first call done"), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_response_created("resp-3"), + responses::ev_function_call("call-2", &tool_name, "{\"message\":\"second\"}"), + responses::ev_completed("resp-3"), + ]), + ) + .await; + mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_assistant_message("msg-2", "second call done"), + responses::ev_completed("resp-4"), + ]), + ) + .await; + + let rmcp_test_server_bin = stdio_server_bin()?; + let fixture = test_codex() + .with_config(move |config| { + let mut servers = config.mcp_servers.get().clone(); + servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_EXIT_AFTER_CALL".to_string(), + "1".to_string(), + )])), + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + config + .mcp_servers + .set(servers) + .expect("test mcp servers should accept any configuration"); + }) + .build(&server) + .await?; + + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "call the rmcp echo tool".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + model: session_model.clone(), + effort: None, + summary: None, + collaboration_mode: None, + personality: None, + }) + .await?; + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + assert!( + end.result.is_ok(), + "first tool call should succeed: {:?}", + end.result + ); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + // Allow the flaky server to exit after the first call so the second call hits a closed + // transport and has to recover. + sleep(Duration::from_millis(300)).await; + + fixture + .codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "call the rmcp echo tool again".into(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + model: session_model, + effort: None, + summary: None, + collaboration_mode: None, + personality: None, + }) + .await?; + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + assert!( + end.result.is_ok(), + "second tool call should succeed after transport restart: {:?}", + end.result + ); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await; + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] #[serial(mcp_test_value)] async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { diff --git a/codex-rs/rmcp-client/src/bin/test_stdio_server.rs b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs index d7708bf5ed3..06cfff96fd8 100644 --- a/codex-rs/rmcp-client/src/bin/test_stdio_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs @@ -1,6 +1,9 @@ use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; use rmcp::ErrorData as McpError; use rmcp::ServiceExt; @@ -25,6 +28,7 @@ use rmcp::model::Tool; use serde::Deserialize; use serde_json::json; use tokio::task; +use tokio::time; #[derive(Clone)] struct TestToolServer { @@ -35,6 +39,8 @@ struct TestToolServer { const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; +const EXIT_AFTER_CALL_ENV_VAR: &str = "MCP_TEST_EXIT_AFTER_CALL"; +static SHOULD_EXIT_AFTER_CALL: AtomicBool = AtomicBool::new(true); const SMALL_PNG_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg=="; pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) { @@ -295,6 +301,13 @@ impl ServerHandler for TestToolServer { request: CallToolRequestParams, _context: rmcp::service::RequestContext, ) -> Result { + let exit_after_call = std::env::var(EXIT_AFTER_CALL_ENV_VAR) + .map(|value| { + let value = value.trim(); + value == "1" || value.eq_ignore_ascii_case("true") + }) + .unwrap_or(false); + match request.name.as_ref() { "echo" => { let args: EchoArgs = match request.arguments { @@ -316,6 +329,13 @@ impl ServerHandler for TestToolServer { "env": env_snapshot.get("MCP_TEST_VALUE"), }); + if exit_after_call && SHOULD_EXIT_AFTER_CALL.swap(false, Ordering::SeqCst) { + task::spawn(async { + time::sleep(Duration::from_millis(200)).await; + std::process::exit(0); + }); + } + Ok(CallToolResult { content: Vec::new(), structured_content: Some(structured_content), @@ -340,6 +360,13 @@ impl ServerHandler for TestToolServer { ) })?; + if exit_after_call && SHOULD_EXIT_AFTER_CALL.swap(false, Ordering::SeqCst) { + task::spawn(async { + time::sleep(Duration::from_millis(200)).await; + std::process::exit(0); + }); + } + Ok(CallToolResult::success(vec![rmcp::model::Content::image( data_b64, mime_type, )]))