Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 245 additions & 1 deletion codex-rs/core/tests/suite/rmcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use codex_core::CodexAuth;
use codex_core::config::types::McpServerConfig;
use codex_core::config::types::McpServerTransportConfig;
use codex_core::models_manager::manager::RefreshStrategy;

use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::openai_models::ConfigShellToolType;
use codex_protocol::openai_models::InputModality;
Expand All @@ -28,6 +27,10 @@ use codex_protocol::protocol::McpToolCallBeginEvent;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::user_input::UserInput;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::cargo_bin;
use core_test_support::responses;
use core_test_support::responses::mount_models_once;
Expand All @@ -36,6 +39,13 @@ use core_test_support::skip_if_no_network;
use core_test_support::stdio_server_bin;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use futures::FutureExt;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::FormElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParams;
use rmcp::model::ProtocolVersion;
use serde_json::Value;
use serde_json::json;
use serial_test::serial;
Expand Down Expand Up @@ -1056,6 +1066,221 @@ async fn streamable_http_with_oauth_round_trip_impl() -> anyhow::Result<()> {
Ok(())
}

/// This test writes to a fallback credentials file in CODEX_HOME.
#[serial(codex_home)]
#[test]
fn streamable_http_with_oauth_refresh_adopts_rotated_credentials() -> anyhow::Result<()> {
const TEST_STACK_SIZE_BYTES: usize = 8 * 1024 * 1024;

let handle = std::thread::Builder::new()
.name("streamable_http_with_oauth_refresh_adopts_rotated_credentials".to_string())
.stack_size(TEST_STACK_SIZE_BYTES)
.spawn(|| -> anyhow::Result<()> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(1)
.enable_all()
.build()?;
runtime.block_on(streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl())
})?;

match handle.join() {
Ok(result) => result,
Err(_) => Err(anyhow::anyhow!(
"streamable_http_with_oauth_refresh_adopts_rotated_credentials thread panicked"
)),
}
}

#[allow(clippy::expect_used)]
async fn streamable_http_with_oauth_refresh_adopts_rotated_credentials_impl() -> anyhow::Result<()>
{
skip_if_no_network!(Ok(()));

let server_name = "rmcp_http_oauth_refresh_race";
let initial_access_token = "initial-access-token";
let initial_refresh_token = "initial-refresh-token";
let rotated_access_token = "rotated-access-token";
let rotated_refresh_token = "rotated-refresh-token";
let rmcp_http_server_bin = match cargo_bin("test_streamable_http_server") {
Ok(path) => path,
Err(err) => {
eprintln!("test_streamable_http_server binary not available, skipping test: {err}");
return Ok(());
}
};

let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let server_url = format!("http://{bind_addr}/mcp");

let mut http_server_child = Command::new(&rmcp_http_server_bin)
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.env("MCP_EXPECT_BEARER", initial_access_token)
.env("MCP_EXPECT_REFRESH_TOKEN", initial_refresh_token)
.env("MCP_REFRESH_NEXT_ACCESS_TOKEN", rotated_access_token)
.env("MCP_REFRESH_NEXT_REFRESH_TOKEN", rotated_refresh_token)
.env("MCP_REFRESH_EXPIRES_IN", "3600")
.env("MCP_REFRESH_SINGLE_USE", "1")
.spawn()?;

wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
.await?;

let temp_home = tempdir()?;
let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
let initial_expires_at = SystemTime::now()
.checked_add(Duration::from_secs(1))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
write_fallback_oauth_tokens_with_expiry(
temp_home.path(),
server_name,
&server_url,
"test-client-id",
initial_access_token,
initial_refresh_token,
initial_expires_at,
)?;

let client_a = RmcpClient::new_streamable_http_client(
server_name,
&server_url,
None,
None,
None,
OAuthCredentialsStoreMode::File,
)
.await?;
let client_b = RmcpClient::new_streamable_http_client(
server_name,
&server_url,
None,
None,
None,
OAuthCredentialsStoreMode::File,
)
.await?;

client_a
.initialize(
rmcp_initialize_params(),
Some(Duration::from_secs(5)),
noop_send_elicitation(),
)
.await?;
client_b
.initialize(
rmcp_initialize_params(),
Some(Duration::from_secs(5)),
noop_send_elicitation(),
)
.await?;

let tools_a = client_a
.list_tools(None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(tools_a.tools.len(), 1);
assert_eq!(tools_a.tools[0].name.as_ref(), "echo");
assert_stored_oauth_tokens(
temp_home.path(),
rotated_access_token,
rotated_refresh_token,
)?;

let tools_b = client_b
.list_tools(None, Some(Duration::from_secs(5)))
.await?;
assert_eq!(tools_b.tools.len(), 1);
assert_eq!(tools_b.tools[0].name.as_ref(), "echo");
assert_stored_oauth_tokens(
temp_home.path(),
rotated_access_token,
rotated_refresh_token,
)?;

match http_server_child.try_wait() {
Ok(Some(_)) => {}
Ok(None) => {
let _ = http_server_child.kill().await;
}
Err(error) => {
eprintln!("failed to check streamable http oauth server status: {error}");
let _ = http_server_child.kill().await;
}
}
if let Err(error) = http_server_child.wait().await {
eprintln!("failed to await streamable http oauth server shutdown: {error}");
}

Ok(())
}

fn rmcp_initialize_params() -> InitializeRequestParams {
InitializeRequestParams {
meta: None,
capabilities: ClientCapabilities {
experimental: None,
extensions: None,
roots: None,
sampling: None,
elicitation: Some(ElicitationCapability {
form: Some(FormElicitationCapability {
schema_validation: None,
}),
url: None,
}),
tasks: None,
},
client_info: Implementation {
name: "codex-test".into(),
version: "0.0.0-test".into(),
title: Some("Codex rmcp oauth refresh test".into()),
description: None,
icons: None,
website_url: None,
},
protocol_version: ProtocolVersion::V_2025_06_18,
}
}

fn noop_send_elicitation() -> codex_rmcp_client::SendElicitation {
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
})
}
.boxed()
})
}

fn assert_stored_oauth_tokens(
home: &Path,
expected_access_token: &str,
expected_refresh_token: &str,
) -> anyhow::Result<()> {
let file_path = home.join(".credentials.json");
let stored: Value = serde_json::from_slice(&fs::read(&file_path)?)?;
let entry = stored
.get("stub")
.and_then(Value::as_object)
.ok_or_else(|| anyhow::anyhow!("expected fallback OAuth credentials entry"))?;
assert_eq!(
entry.get("access_token").and_then(Value::as_str),
Some(expected_access_token)
);
assert_eq!(
entry.get("refresh_token").and_then(Value::as_str),
Some(expected_refresh_token)
);
Ok(())
}

async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
Expand Down Expand Up @@ -1111,7 +1336,26 @@ fn write_fallback_oauth_tokens(
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
write_fallback_oauth_tokens_with_expiry(
home,
server_name,
server_url,
client_id,
access_token,
refresh_token,
expires_at,
)
}

fn write_fallback_oauth_tokens_with_expiry(
home: &Path,
server_name: &str,
server_url: &str,
client_id: &str,
access_token: &str,
refresh_token: &str,
expires_at: u64,
) -> anyhow::Result<()> {
let store = serde_json::json!({
"stub": {
"server_name": server_name,
Expand Down
Loading
Loading