Skip to content

Commit 6d020c9

Browse files
authored
fix(auth): preserve configured reqwest client (#917)
1 parent de898dd commit 6d020c9

1 file changed

Lines changed: 156 additions & 1 deletion

File tree

crates/rmcp/src/transport/auth.rs

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,8 +944,20 @@ impl AuthorizationManager {
944944
Ok(false)
945945
}
946946

947+
/// Use a caller-configured `reqwest::Client` for every OAuth HTTP operation,
948+
/// preserving all of its settings (proxy, TLS, timeout, default headers).
949+
///
950+
/// The same client is reused for all requests, so its own redirect policy applies
951+
/// and [`OAuthHttpRedirectPolicy::Stop`] is not enforced for token operations.
952+
/// Callers needing strict no-redirect handling should pass a custom
953+
/// [`OAuthHttpClient`] to [`AuthorizationManager::new_with_oauth_http_client`].
947954
pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> {
948-
self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?);
955+
// One client for both modes: a built reqwest::Client can't be rebuilt as a
956+
// no-redirect variant without dropping the caller's configuration.
957+
self.http_client = Arc::new(ReqwestOAuthHttpClient {
958+
follow_redirects: http_client.clone(),
959+
stop_redirects: http_client,
960+
});
949961
self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Follow;
950962
Ok(())
951963
}
@@ -4871,6 +4883,149 @@ mod tests {
48714883
);
48724884
}
48734885

4886+
#[tokio::test]
4887+
async fn exchange_code_uses_client_configured_by_with_client() {
4888+
use axum::{Router, body::Body, http::Response, routing::post};
4889+
4890+
let received_header = Arc::new(std::sync::Mutex::new(None));
4891+
let received_header_clone = Arc::clone(&received_header);
4892+
let app = Router::new().route(
4893+
"/token",
4894+
post(move |headers: axum::http::HeaderMap| {
4895+
let received_header = Arc::clone(&received_header_clone);
4896+
async move {
4897+
*received_header.lock().unwrap() = headers
4898+
.get("x-custom-client")
4899+
.and_then(|value| value.to_str().ok())
4900+
.map(str::to_string);
4901+
Response::builder()
4902+
.status(200)
4903+
.header("content-type", "application/json")
4904+
.body(Body::from(
4905+
r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#,
4906+
))
4907+
.unwrap()
4908+
}
4909+
}),
4910+
);
4911+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
4912+
let addr = listener.local_addr().unwrap();
4913+
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
4914+
4915+
let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
4916+
authorization_endpoint: format!("http://{addr}/authorize"),
4917+
token_endpoint: format!("http://{addr}/token"),
4918+
..Default::default()
4919+
}))
4920+
.await;
4921+
let mut default_headers = reqwest::header::HeaderMap::new();
4922+
default_headers.insert("x-custom-client", "configured".parse().unwrap());
4923+
manager
4924+
.with_client(
4925+
reqwest::Client::builder()
4926+
.default_headers(default_headers)
4927+
.build()
4928+
.unwrap(),
4929+
)
4930+
.unwrap();
4931+
manager.configure_client(test_client_config()).unwrap();
4932+
let authorization_url = manager.get_authorization_url(&[]).await.unwrap();
4933+
let state = Url::parse(&authorization_url)
4934+
.unwrap()
4935+
.query_pairs()
4936+
.find(|(name, _)| name == "state")
4937+
.unwrap()
4938+
.1
4939+
.into_owned();
4940+
4941+
manager
4942+
.exchange_code_for_token("authorization-code", &state)
4943+
.await
4944+
.unwrap();
4945+
4946+
assert_eq!(
4947+
received_header.lock().unwrap().as_deref(),
4948+
Some("configured")
4949+
);
4950+
}
4951+
4952+
#[tokio::test]
4953+
async fn exchange_code_follows_redirects_with_with_client() {
4954+
use std::sync::atomic::{AtomicBool, Ordering};
4955+
4956+
use axum::{
4957+
Router,
4958+
body::Body,
4959+
http::{Response, StatusCode},
4960+
routing::post,
4961+
};
4962+
4963+
// The token endpoint replies with a 307 redirect; the with_client path reuses
4964+
// the caller's redirect-following client, so the request is expected to follow
4965+
// it to the final endpoint that returns the token.
4966+
let final_endpoint_hit = Arc::new(AtomicBool::new(false));
4967+
let final_endpoint_hit_clone = Arc::clone(&final_endpoint_hit);
4968+
let app = Router::new()
4969+
.route(
4970+
"/token",
4971+
post(|| async {
4972+
Response::builder()
4973+
.status(StatusCode::TEMPORARY_REDIRECT)
4974+
.header("location", "/token-final")
4975+
.body(Body::empty())
4976+
.unwrap()
4977+
}),
4978+
)
4979+
.route(
4980+
"/token-final",
4981+
post(move || {
4982+
let final_endpoint_hit = Arc::clone(&final_endpoint_hit_clone);
4983+
async move {
4984+
final_endpoint_hit.store(true, Ordering::SeqCst);
4985+
Response::builder()
4986+
.status(200)
4987+
.header("content-type", "application/json")
4988+
.body(Body::from(
4989+
r#"{"access_token":"redirected-token","token_type":"Bearer","expires_in":3600}"#,
4990+
))
4991+
.unwrap()
4992+
}
4993+
}),
4994+
);
4995+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
4996+
let addr = listener.local_addr().unwrap();
4997+
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
4998+
4999+
let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
5000+
authorization_endpoint: format!("http://{addr}/authorize"),
5001+
token_endpoint: format!("http://{addr}/token"),
5002+
..Default::default()
5003+
}))
5004+
.await;
5005+
manager
5006+
.with_client(reqwest::Client::builder().build().unwrap())
5007+
.unwrap();
5008+
manager.configure_client(test_client_config()).unwrap();
5009+
let authorization_url = manager.get_authorization_url(&[]).await.unwrap();
5010+
let state = Url::parse(&authorization_url)
5011+
.unwrap()
5012+
.query_pairs()
5013+
.find(|(name, _)| name == "state")
5014+
.unwrap()
5015+
.1
5016+
.into_owned();
5017+
5018+
manager
5019+
.exchange_code_for_token("authorization-code", &state)
5020+
.await
5021+
.unwrap();
5022+
5023+
assert!(
5024+
final_endpoint_hit.load(Ordering::SeqCst),
5025+
"with_client path should follow redirects on token exchange"
5026+
);
5027+
}
5028+
48745029
async fn start_token_server() -> (String, Arc<std::sync::Mutex<Option<String>>>) {
48755030
use axum::{Router, body::Body, http::Response, routing::post};
48765031
let captured: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));

0 commit comments

Comments
 (0)