diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index df389c5ef74..d25516706ad 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -12,7 +12,6 @@ use codex_core::CodexAuth; use codex_core::config::types::McpServerConfig; use codex_core::config::types::McpServerTransportConfig; use codex_core::models_manager::manager::RefreshStrategy; - use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::InputModality; @@ -28,6 +27,10 @@ use codex_protocol::protocol::McpToolCallBeginEvent; use codex_protocol::protocol::Op; use codex_protocol::protocol::SandboxPolicy; use codex_protocol::user_input::UserInput; +use codex_rmcp_client::ElicitationAction; +use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::OAuthCredentialsStoreMode; +use codex_rmcp_client::RmcpClient; use codex_utils_cargo_bin::cargo_bin; use core_test_support::responses; use core_test_support::responses::mount_models_once; @@ -36,6 +39,13 @@ use core_test_support::skip_if_no_network; use core_test_support::stdio_server_bin; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use futures::FutureExt; +use rmcp::model::ClientCapabilities; +use rmcp::model::ElicitationCapability; +use rmcp::model::FormElicitationCapability; +use rmcp::model::Implementation; +use rmcp::model::InitializeRequestParams; +use rmcp::model::ProtocolVersion; use serde_json::Value; use serde_json::json; use serial_test::serial; @@ -1056,6 +1066,231 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> { Ok(()) } +/// This test writes to a fallback credentials file in CODEX_HOME. +#[serial(codex_home)] +#[test] +fn streamable_http_with_oauth_refresh_adopts_rotated_credentials() -> anyhow::Result<()> { + const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024; + + let handle = std::thread::Builder::new() + .name("streamable_http_with_oauth_refresh_adopts_rotated_credentials".to_string()) + .stack_size(TEST_STACK_SIZE_BYTES) + .spawn(|| -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?; + runtime.block_on(streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl()) + })?; + + match handle.join() { + Ok(result) => result, + Err(_) => Err(anyhow::anyhow!( + "streamable_http_with_oauth_refresh_adopts_rotated_credentials thread panicked" + )), + } +} + +#[allow(clippy::expect_used)] +async fn streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl() -> anyhow::Result<()> +{ + skip_if_no_network!(Ok(())); + + let server_name = "rmcp_http_oauth_refresh_race"; + let initial_access_token = "initial-access-token"; + let initial_refresh_token = "initial-refresh-token"; + let rotated_access_token = "rotated-access-token"; + let rotated_refresh_token = "rotated-refresh-token"; + let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") { + Ok(path) => path, + Err(err) => { + eprintln!("test_streamable_http_server binary not available, skipping test: {err}"); + return Ok(()); + } + }; + + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + let bind_addr = format!("127.0.0.1:{port}"); + let server_url = format!("http://{bind_addr}/mcp"); + + let mut http_server_child = Command::new(&rmcp_http_server_bin) + .kill_on_drop(true) + .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) + .env("MCP_EXPECT_BEARER", initial_access_token) + .env("MCP_EXPECT_REFRESH_TOKEN", initial_refresh_token) + .env("MCP_REFRESH_NEXT_ACCESS_TOKEN", rotated_access_token) + .env("MCP_REFRESH_NEXT_REFRESH_TOKEN", rotated_refresh_token) + .env("MCP_REFRESH_EXPIRES_IN", "3600") + .env("MCP_REFRESH_SINGLE_USE", "1") + .spawn()?; + + wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5)) + .await?; + + let temp_home = tempdir()?; + let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str()); + let initial_expires_at = SystemTime::now() + .checked_add(Duration::from_secs(1)) + .ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))? + .duration_since(UNIX_EPOCH)? + .as_millis() as u64; + write_fallback_oauth_tokens_with_expiry( + temp_home.path(), + server_name, + &server_url, + "test-client-id", + initial_access_token, + initial_refresh_token, + initial_expires_at, + )?; + + let client_a = RmcpClient::new_streamable_http_client( + server_name, + &server_url, + None, + None, + None, + OAuthCredentialsStoreMode::File, + ) + .await?; + let client_b = RmcpClient::new_streamable_http_client( + server_name, + &server_url, + None, + None, + None, + OAuthCredentialsStoreMode::File, + ) + .await?; + + client_a + .initialize( + rmcp_initialize_params(), + Some(Duration::from_secs(5)), + noop_send_elicitation(), + ) + .await?; + client_b + .initialize( + rmcp_initialize_params(), + Some(Duration::from_secs(5)), + noop_send_elicitation(), + ) + .await?; + + let tools_a = client_a + .list_tools(None, Some(Duration::from_secs(5))) + .await?; + assert_eq!(tools_a.tools.len(), 1); + assert_eq!(tools_a.tools[0].name.as_ref(), "echo"); + assert_stored_oauth_tokens( + temp_home.path(), + server_name, + &server_url, + rotated_access_token, + rotated_refresh_token, + )?; + + let tools_b = client_b + .list_tools(None, Some(Duration::from_secs(5))) + .await?; + assert_eq!(tools_b.tools.len(), 1); + assert_eq!(tools_b.tools[0].name.as_ref(), "echo"); + assert_stored_oauth_tokens( + temp_home.path(), + server_name, + &server_url, + rotated_access_token, + rotated_refresh_token, + )?; + + match http_server_child.try_wait() { + Ok(Some(_)) => {} + Ok(None) => { + let _ = http_server_child.kill().await; + } + Err(error) => { + eprintln!("failed to check streamable http oauth server status: {error}"); + let _ = http_server_child.kill().await; + } + } + if let Err(error) = http_server_child.wait().await { + eprintln!("failed to await streamable http oauth server shutdown: {error}"); + } + + Ok(()) +} + +fn rmcp_initialize_params() -> InitializeRequestParams { + InitializeRequestParams { + meta: None, + capabilities: ClientCapabilities { + experimental: None, + extensions: None, + roots: None, + sampling: None, + elicitation: Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None, + }), + url: None, + }), + tasks: None, + }, + client_info: Implementation { + name: "codex-test".into(), + version: "0.0.0-test".into(), + title: Some("Codex rmcp oauth refresh test".into()), + description: None, + icons: None, + website_url: None, + }, + protocol_version: ProtocolVersion::V_2025_06_18, + } +} + +fn noop_send_elicitation() -> codex_rmcp_client::SendElicitation { + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + }) + } + .boxed() + }) +} + +fn assert_stored_oauth_tokens( + home: &Path, + server_name: &str, + server_url: &str, + expected_access_token: &str, + expected_refresh_token: &str, +) -> anyhow::Result<()> { + let file_path = home.join(".credentials.json"); + let stored: Value = serde_json::from_slice(&fs::read(&file_path)?)?; + let entries = stored + .as_object() + .ok_or_else(|| anyhow::anyhow!("expected fallback OAuth credential map"))?; + let has_expected_tokens = entries.values().any(|entry| { + entry.as_object().is_some_and(|entry| { + entry.get("server_name").and_then(Value::as_str) == Some(server_name) + && entry.get("server_url").and_then(Value::as_str) == Some(server_url) + && entry.get("access_token").and_then(Value::as_str) == Some(expected_access_token) + && entry.get("refresh_token").and_then(Value::as_str) + == Some(expected_refresh_token) + }) + }); + assert!( + has_expected_tokens, + "expected stored OAuth credentials for {server_name} at {server_url} to include access_token={expected_access_token} refresh_token={expected_refresh_token}, got {stored}", + ); + Ok(()) +} + async fn wait_for_streamable_http_server( server_child: &mut Child, address: &str, @@ -1111,7 +1346,26 @@ fn write_fallback_oauth_tokens( .ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))? .duration_since(UNIX_EPOCH)? .as_millis() as u64; + write_fallback_oauth_tokens_with_expiry( + home, + server_name, + server_url, + client_id, + access_token, + refresh_token, + expires_at, + ) +} +fn write_fallback_oauth_tokens_with_expiry( + home: &Path, + server_name: &str, + server_url: &str, + client_id: &str, + access_token: &str, + refresh_token: &str, + expires_at: u64, +) -> anyhow::Result<()> { let store = serde_json::json!({ "stub": { "server_name": server_name, diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs index 821850d2a8e..be1adf478e1 100644 --- a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use axum::Router; use axum::body::Body; +use axum::extract::Form; use axum::extract::State; use axum::http::Request; use axum::http::StatusCode; @@ -15,6 +16,7 @@ use axum::middleware; use axum::middleware::Next; use axum::response::Response; use axum::routing::get; +use axum::routing::post; use rmcp::ErrorData as McpError; use rmcp::handler::server::ServerHandler; use rmcp::model::CallToolRequestParams; @@ -39,6 +41,8 @@ use rmcp::transport::StreamableHttpService; use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; use serde::Deserialize; use serde_json::json; +use tokio::sync::Mutex; +use tokio::sync::RwLock; use tokio::task; #[derive(Clone)] @@ -48,6 +52,22 @@ struct TestToolServer { resource_templates: Arc>, } +#[derive(Clone)] +struct AuthState { + current_bearer: Arc>>, + refresh_state: Option>>, +} + +#[derive(Debug)] +struct RefreshTokenState { + current_refresh_token: String, + next_access_token: String, + next_refresh_token: String, + expires_in: u64, + single_use: bool, + used_once: bool, +} + const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; @@ -263,6 +283,15 @@ async fn main() -> Result<(), Box> { }; eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp"); + let auth_state = AuthState { + current_bearer: Arc::new(RwLock::new( + std::env::var("MCP_EXPECT_BEARER") + .ok() + .map(|token| format!("Bearer {token}")), + )), + refresh_state: refresh_state_from_env(), + }; + let router = Router::new() .route( "/.well-known/oauth-authorization-server/mcp", @@ -284,6 +313,7 @@ async fn main() -> Result<(), Box> { } }), ) + .route("/oauth/token", post(oauth_refresh_token)) .nest_service( "/mcp", StreamableHttpService::new( @@ -291,28 +321,108 @@ async fn main() -> Result<(), Box> { Arc::new(LocalSessionManager::default()), StreamableHttpServerConfig::default(), ), - ); - - let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") { - let expected = Arc::new(format!("Bearer {token}")); - router.layer(middleware::from_fn_with_state(expected, require_bearer)) - } else { - router - }; + ) + .with_state(auth_state.clone()) + .layer(middleware::from_fn_with_state(auth_state, require_bearer)); axum::serve(listener, router).await?; task::yield_now().await; Ok(()) } +fn refresh_state_from_env() -> Option>> { + let current_refresh_token = std::env::var("MCP_EXPECT_REFRESH_TOKEN").ok()?; + let next_access_token = std::env::var("MCP_REFRESH_NEXT_ACCESS_TOKEN").ok()?; + let next_refresh_token = std::env::var("MCP_REFRESH_NEXT_REFRESH_TOKEN").ok()?; + let expires_in = std::env::var("MCP_REFRESH_EXPIRES_IN") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(3600); + let single_use = std::env::var("MCP_REFRESH_SINGLE_USE") + .ok() + .is_some_and(|value| value == "1"); + + Some(Arc::new(Mutex::new(RefreshTokenState { + current_refresh_token, + next_access_token, + next_refresh_token, + expires_in, + single_use, + used_once: false, + }))) +} + +async fn oauth_refresh_token( + State(state): State, + Form(form): Form>, +) -> Response { + let Some(refresh_state) = state.refresh_state.clone() else { + return json_response(StatusCode::NOT_FOUND, json!({ "error": "not_found" })); + }; + + if form.get("grant_type").map(String::as_str) != Some("refresh_token") { + return json_response( + StatusCode::BAD_REQUEST, + json!({ "error": "unsupported_grant_type" }), + ); + } + + let provided_refresh_token = form.get("refresh_token").map(String::as_str); + let mut refresh_state = refresh_state.lock().await; + if refresh_state.single_use && refresh_state.used_once { + return json_response( + StatusCode::UNAUTHORIZED, + json!({ + "error": "invalid_grant", + "error_description": "refresh token was already used", + "code": "refresh_token_reused", + }), + ); + } + if provided_refresh_token != Some(refresh_state.current_refresh_token.as_str()) { + return json_response( + StatusCode::UNAUTHORIZED, + json!({ + "error": "invalid_grant", + "error_description": "refresh token was already used", + "code": "refresh_token_reused", + }), + ); + } + + let access_token = refresh_state.next_access_token.clone(); + let refresh_token = refresh_state.next_refresh_token.clone(); + let expires_in = refresh_state.expires_in; + refresh_state.current_refresh_token = refresh_token.clone(); + refresh_state.used_once = true; + *state.current_bearer.write().await = Some(format!("Bearer {access_token}")); + + json_response( + StatusCode::OK, + json!({ + "access_token": access_token, + "token_type": "Bearer", + "refresh_token": refresh_token, + "expires_in": expires_in, + }), + ) +} + async fn require_bearer( - State(expected): State>, + State(state): State, request: Request, next: Next, ) -> Result { - if request.uri().path().contains("/.well-known/") { + let request_path = request.uri().path(); + if request_path.contains("/.well-known/") || request_path.contains("/oauth/token") { return Ok(next.run(request).await); } + + let expected = state.current_bearer.read().await.clone(); + let Some(expected) = expected else { + return Ok(next.run(request).await); + }; + if request .headers() .get(AUTHORIZATION) @@ -323,3 +433,14 @@ async fn require_bearer( Err(StatusCode::UNAUTHORIZED) } } + +fn json_response(status: StatusCode, body: serde_json::Value) -> Response { + #[expect(clippy::expect_used)] + Response::builder() + .status(status) + .header(CONTENT_TYPE, "application/json") + .body(Body::from( + serde_json::to_vec(&body).expect("failed to serialize JSON response"), + )) + .expect("valid JSON response") +} diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index cdb64ff1517..75599b73250 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -25,7 +25,10 @@ use oauth2::RefreshToken; use oauth2::Scope; use oauth2::TokenResponse; use oauth2::basic::BasicTokenType; +use rmcp::transport::auth::CredentialStore; +use rmcp::transport::auth::InMemoryCredentialStore; use rmcp::transport::auth::OAuthTokenResponse; +use rmcp::transport::auth::StoredCredentials; use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; @@ -273,15 +276,32 @@ struct OAuthPersistorInner { server_name: String, url: String, authorization_manager: Arc>, + runtime_credentials: InMemoryCredentialStore, store_mode: OAuthCredentialsStoreMode, last_credentials: Mutex>, } +#[derive(Debug, Clone, PartialEq)] +enum GuardedRefreshOutcome { + NoAction, + ReloadedChanged(StoredOAuthTokens), + ReloadedNoChange, + MissingOrInvalid, + ReloadFailed, +} + +#[derive(Debug, PartialEq)] +enum GuardedRefreshPersistedCredentials { + Loaded(Option), + ReloadFailed, +} + impl OAuthPersistor { pub(crate) fn new( server_name: String, url: String, authorization_manager: Arc>, + runtime_credentials: InMemoryCredentialStore, store_mode: OAuthCredentialsStoreMode, initial_credentials: Option, ) -> Self { @@ -290,6 +310,7 @@ impl OAuthPersistor { server_name, url, authorization_manager, + runtime_credentials, store_mode, last_credentials: Mutex::new(initial_credentials), }), @@ -350,28 +371,220 @@ impl OAuthPersistor { Ok(()) } + /// Guard refreshes against multi-process refresh-token reuse. + /// + /// MCP OAuth credentials live in shared storage, but each Codex process also keeps an + /// in-memory snapshot. Before refreshing, reload the shared credentials and compare them to + /// the cached copy: + /// - if the local cache was cleared, reload shared storage first so this process can recover + /// when another process logs in and persists fresh credentials; + /// - if shared storage changed, another process already refreshed, so adopt those credentials + /// in the live runtime and skip the local refresh; + /// - if shared storage is unchanged, this process still owns the refresh and can rotate the + /// tokens with the authority; + /// - if shared storage no longer has credentials, treat that as logged out and clear the live + /// runtime instead of sending a stale refresh token. pub(crate) async fn refresh_if_needed(&self) -> Result<()> { - let expires_at = { + let mut cached_credentials = { let guard = self.inner.last_credentials.lock().await; - guard.as_ref().and_then(|tokens| tokens.expires_at) + guard.clone() }; - if !token_needs_refresh(expires_at) { - return Ok(()); + if cached_credentials.is_none() + && let Some(credentials) = load_oauth_tokens_when_cache_missing( + &self.inner.server_name, + &self.inner.url, + self.inner.store_mode, + ) + { + self.apply_runtime_credentials(Some(credentials.clone())) + .await?; + cached_credentials = Some(credentials); } + match self.guarded_refresh_outcome(cached_credentials.as_ref()) { + GuardedRefreshOutcome::NoAction => Ok(()), + GuardedRefreshOutcome::ReloadedChanged(credentials) => { + self.apply_runtime_credentials(Some(credentials)).await + } + GuardedRefreshOutcome::ReloadedNoChange => { + { + let manager = self.inner.authorization_manager.clone(); + let guard = manager.lock().await; + guard.refresh_token().await.with_context(|| { + format!( + "failed to refresh OAuth tokens for server {}", + self.inner.server_name + ) + })?; + } + + self.persist_if_needed().await + } + GuardedRefreshOutcome::MissingOrInvalid => self.apply_runtime_credentials(None).await, + GuardedRefreshOutcome::ReloadFailed => Ok(()), + } + } + + fn guarded_refresh_outcome( + &self, + cached_credentials: Option<&StoredOAuthTokens>, + ) -> GuardedRefreshOutcome { + let Some(cached_credentials) = cached_credentials else { + return GuardedRefreshOutcome::NoAction; + }; + + if !token_needs_refresh(cached_credentials.expires_at) { + return GuardedRefreshOutcome::NoAction; + } + + match load_oauth_tokens_for_guarded_refresh( + &self.inner.server_name, + &self.inner.url, + self.inner.store_mode, + ) { + GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => { + determine_guarded_refresh_outcome(cached_credentials, persisted_credentials) + } + GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed, + } + } + + async fn apply_runtime_credentials( + &self, + credentials: Option, + ) -> Result<()> { { let manager = self.inner.authorization_manager.clone(); - let guard = manager.lock().await; - guard.refresh_token().await.with_context(|| { - format!( - "failed to refresh OAuth tokens for server {}", - self.inner.server_name - ) - })?; + let mut guard = manager.lock().await; + + match credentials.as_ref() { + Some(credentials) => { + self.inner + .runtime_credentials + .save(StoredCredentials { + client_id: credentials.client_id.clone(), + token_response: Some(credentials.token_response.0.clone()), + }) + .await?; + guard + .configure_client_id(&credentials.client_id) + .with_context(|| { + format!( + "failed to reconfigure OAuth client for server {}", + self.inner.server_name + ) + })?; + } + None => { + self.inner.runtime_credentials.clear().await?; + } + } + } + + let mut last_credentials = self.inner.last_credentials.lock().await; + *last_credentials = credentials; + Ok(()) + } +} + +fn load_oauth_tokens_for_guarded_refresh( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> GuardedRefreshPersistedCredentials { + let keyring_store = DefaultKeyringStore; + match store_mode { + OAuthCredentialsStoreMode::Auto => { + load_oauth_tokens_for_guarded_refresh_with_keyring_fallback( + &keyring_store, + server_name, + url, + ) + } + OAuthCredentialsStoreMode::File => guarded_refresh_persisted_credentials_from_load_result( + load_oauth_tokens_from_file(server_name, url), + server_name, + ), + OAuthCredentialsStoreMode::Keyring => { + guarded_refresh_persisted_credentials_from_load_result( + load_oauth_tokens_from_keyring(&keyring_store, server_name, url) + .with_context(|| "failed to read OAuth tokens from keyring".to_string()), + server_name, + ) + } + } +} + +fn load_oauth_tokens_when_cache_missing( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> Option { + match load_oauth_tokens_for_guarded_refresh(server_name, url, store_mode) { + GuardedRefreshPersistedCredentials::Loaded(Some(credentials)) => Some(credentials), + GuardedRefreshPersistedCredentials::Loaded(None) + | GuardedRefreshPersistedCredentials::ReloadFailed => None, + } +} + +fn load_oauth_tokens_for_guarded_refresh_with_keyring_fallback( + keyring_store: &K, + server_name: &str, + url: &str, +) -> GuardedRefreshPersistedCredentials { + match load_oauth_tokens_from_keyring(keyring_store, server_name, url) { + Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)), + Ok(None) => guarded_refresh_persisted_credentials_from_load_result( + load_oauth_tokens_from_file(server_name, url), + server_name, + ), + Err(error) => { + warn!("failed to read OAuth tokens from keyring: {error}"); + match load_oauth_tokens_from_file(server_name, url) { + Ok(Some(tokens)) => GuardedRefreshPersistedCredentials::Loaded(Some(tokens)), + Ok(None) => { + warn!( + "failed to reload OAuth tokens for server {server_name}: keyring read failed and no fallback file credentials were available" + ); + GuardedRefreshPersistedCredentials::ReloadFailed + } + Err(file_error) => { + warn!( + "failed to reload OAuth tokens for server {server_name}: keyring read failed ({error}) and fallback file reload failed: {file_error}" + ); + GuardedRefreshPersistedCredentials::ReloadFailed + } + } } + } +} - self.persist_if_needed().await +#[cfg(test)] +fn guarded_refresh_outcome_from_load_result( + cached_credentials: &StoredOAuthTokens, + persisted_credentials: Result>, + server_name: &str, +) -> GuardedRefreshOutcome { + match guarded_refresh_persisted_credentials_from_load_result(persisted_credentials, server_name) + { + GuardedRefreshPersistedCredentials::Loaded(persisted_credentials) => { + determine_guarded_refresh_outcome(cached_credentials, persisted_credentials) + } + GuardedRefreshPersistedCredentials::ReloadFailed => GuardedRefreshOutcome::ReloadFailed, + } +} + +fn guarded_refresh_persisted_credentials_from_load_result( + persisted_credentials: Result>, + server_name: &str, +) -> GuardedRefreshPersistedCredentials { + match persisted_credentials { + Ok(credentials) => GuardedRefreshPersistedCredentials::Loaded(credentials), + Err(error) => { + warn!("failed to reload OAuth tokens for server {server_name}: {error}"); + GuardedRefreshPersistedCredentials::ReloadFailed + } } } @@ -521,6 +734,61 @@ fn token_needs_refresh(expires_at: Option) -> bool { now.saturating_add(REFRESH_SKEW_MILLIS) >= expires_at } +fn determine_guarded_refresh_outcome( + cached_credentials: &StoredOAuthTokens, + persisted_credentials: Option, +) -> GuardedRefreshOutcome { + match persisted_credentials { + Some(persisted_credentials) + if oauth_tokens_equal_for_refresh( + Some(cached_credentials), + Some(&persisted_credentials), + ) => + { + GuardedRefreshOutcome::ReloadedNoChange + } + Some(persisted_credentials) => { + GuardedRefreshOutcome::ReloadedChanged(persisted_credentials) + } + None => GuardedRefreshOutcome::MissingOrInvalid, + } +} + +fn oauth_tokens_equal_for_refresh( + left: Option<&StoredOAuthTokens>, + right: Option<&StoredOAuthTokens>, +) -> bool { + match (left, right) { + (None, None) => true, + (Some(left), Some(right)) => { + left.server_name == right.server_name + && left.url == right.url + && left.client_id == right.client_id + && left.expires_at == right.expires_at + && oauth_token_responses_equal_for_refresh( + &left.token_response, + &right.token_response, + ) + } + _ => false, + } +} + +fn oauth_token_responses_equal_for_refresh( + left: &WrappedOAuthTokenResponse, + right: &WrappedOAuthTokenResponse, +) -> bool { + let left = &left.0; + let right = &right.0; + + left.access_token().secret() == right.access_token().secret() + && left.token_type() == right.token_type() + && left.refresh_token().map(RefreshToken::secret) + == right.refresh_token().map(RefreshToken::secret) + && left.scopes() == right.scopes() + && left.extra_fields() == right.extra_fields() +} + fn compute_store_key(server_name: &str, server_url: &str) -> Result { let mut payload = JsonMap::new(); payload.insert( @@ -855,6 +1123,158 @@ mod tests { assert!(tokens.token_response.0.expires_in().is_none()); } + #[test] + fn guarded_refresh_outcome_reloads_when_persisted_credentials_changed() { + let cached = sample_tokens(); + let mut persisted = sample_tokens(); + persisted + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new("rotated-refresh-token".to_string()))); + persisted + .token_response + .0 + .set_expires_in(Some(&Duration::from_secs(7200))); + persisted.expires_at = super::compute_expires_at_millis(&persisted.token_response.0); + + assert_eq!( + super::determine_guarded_refresh_outcome(&cached, Some(persisted.clone())), + super::GuardedRefreshOutcome::ReloadedChanged(persisted), + ); + } + + #[test] + fn guarded_refresh_outcome_refreshes_when_persisted_credentials_match() { + let cached = sample_tokens(); + let mut persisted = cached.clone(); + persisted + .token_response + .0 + .set_expires_in(Some(&Duration::from_secs(5))); + + assert_eq!( + super::determine_guarded_refresh_outcome(&cached, Some(persisted)), + super::GuardedRefreshOutcome::ReloadedNoChange, + ); + } + + #[test] + fn guarded_refresh_outcome_clears_when_persisted_credentials_are_missing() { + assert_eq!( + super::determine_guarded_refresh_outcome(&sample_tokens(), None), + super::GuardedRefreshOutcome::MissingOrInvalid, + ); + } + + #[test] + fn guarded_refresh_outcome_keeps_state_recoverable_when_reload_fails() { + let error = anyhow::anyhow!("transient read failure"); + + assert_eq!( + super::guarded_refresh_outcome_from_load_result( + &sample_tokens(), + Err(error), + "test-server", + ), + super::GuardedRefreshOutcome::ReloadFailed, + ); + } + + #[test] + fn guarded_refresh_auto_load_keeps_state_recoverable_when_keyring_fails_without_file() { + let _env = TempCodexHome::new(); + let store = MockKeyringStore::default(); + let tokens = sample_tokens(); + let key = super::compute_store_key(&tokens.server_name, &tokens.url) + .expect("store key should compute"); + store.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + assert_eq!( + super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback( + &store, + &tokens.server_name, + &tokens.url, + ), + super::GuardedRefreshPersistedCredentials::ReloadFailed, + ); + } + + #[test] + fn missing_cached_credentials_reload_shared_store_from_file() -> Result<()> { + let _env = TempCodexHome::new(); + let tokens = sample_tokens(); + let expected = tokens.clone(); + super::save_oauth_tokens_to_file(&tokens)?; + + let loaded = super::load_oauth_tokens_when_cache_missing( + &tokens.server_name, + &tokens.url, + OAuthCredentialsStoreMode::File, + ) + .expect("tokens should reload from shared file store"); + assert_tokens_match_without_expiry(&loaded, &expected); + Ok(()) + } + + #[test] + fn missing_cached_credentials_ignore_reload_failures() { + let _env = TempCodexHome::new(); + let store = MockKeyringStore::default(); + let tokens = sample_tokens(); + let key = super::compute_store_key(&tokens.server_name, &tokens.url) + .expect("store key should compute"); + store.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + assert_eq!( + super::load_oauth_tokens_for_guarded_refresh_with_keyring_fallback( + &store, + &tokens.server_name, + &tokens.url, + ), + super::GuardedRefreshPersistedCredentials::ReloadFailed, + ); + assert_eq!( + super::load_oauth_tokens_when_cache_missing( + &tokens.server_name, + &tokens.url, + OAuthCredentialsStoreMode::Auto, + ), + None, + ); + } + + #[test] + fn oauth_tokens_equal_for_refresh_ignores_only_expires_in() { + let left = sample_tokens(); + let mut right = left.clone(); + right + .token_response + .0 + .set_expires_in(Some(&Duration::from_secs(5))); + + assert!(super::oauth_tokens_equal_for_refresh( + Some(&left), + Some(&right), + )); + + let mut different_refresh_token = right.clone(); + different_refresh_token + .token_response + .0 + .set_refresh_token(Some(RefreshToken::new("different-refresh".to_string()))); + assert!(!super::oauth_tokens_equal_for_refresh( + Some(&left), + Some(&different_refresh_token), + )); + + let mut different_expiry = right; + different_expiry.expires_at = different_expiry.expires_at.map(|value| value + 1000); + assert!(!super::oauth_tokens_equal_for_refresh( + Some(&left), + Some(&different_expiry), + )); + } + fn assert_tokens_match_without_expiry( actual: &StoredOAuthTokens, expected: &StoredOAuthTokens, diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index e1d7045968e..eb065cfbb63 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -39,7 +39,10 @@ use rmcp::service::{self}; use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; -use rmcp::transport::auth::OAuthState; +use rmcp::transport::auth::AuthorizationManager; +use rmcp::transport::auth::CredentialStore; +use rmcp::transport::auth::InMemoryCredentialStore; +use rmcp::transport::auth::StoredCredentials; use rmcp::transport::child_process::TokioChildProcess; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use serde_json::Value; @@ -358,6 +361,12 @@ impl RmcpClient { } }; + if let Some(runtime) = &oauth_persistor + && let Err(error) = runtime.refresh_if_needed().await + { + warn!("failed to refresh OAuth tokens before initialize: {error}"); + } + let service = match timeout { Some(duration) => time::timeout(duration, transport) .await @@ -595,22 +604,20 @@ async fn create_oauth_transport_and_runtime( )> { let http_client = apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; - let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?; - - oauth_state - .set_credentials( - &initial_tokens.client_id, - initial_tokens.token_response.0.clone(), - ) + let runtime_credentials = InMemoryCredentialStore::new(); + runtime_credentials + .save(StoredCredentials { + client_id: initial_tokens.client_id.clone(), + token_response: Some(initial_tokens.token_response.0.clone()), + }) .await?; - let manager = match oauth_state { - OAuthState::Authorized(manager) => manager, - OAuthState::Unauthorized(manager) => manager, - OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => { - return Err(anyhow!("unexpected OAuth state during client setup")); - } - }; + let mut manager = AuthorizationManager::new(url.to_string()).await?; + manager.set_credential_store(runtime_credentials.clone()); + manager.with_client(http_client.clone())?; + let metadata = manager.discover_metadata().await?; + manager.set_metadata(metadata); + manager.configure_client_id(&initial_tokens.client_id)?; let auth_client = AuthClient::new(http_client, manager); let auth_manager = auth_client.auth_manager.clone(); @@ -624,6 +631,7 @@ async fn create_oauth_transport_and_runtime( server_name.to_string(), url.to_string(), auth_manager, + runtime_credentials, credentials_store, Some(initial_tokens), );