Skip to content

Commit 2c5e93a

Browse files
committed
fix(auth): preserve configured reqwest client
1 parent de898dd commit 2c5e93a

1 file changed

Lines changed: 72 additions & 1 deletion

File tree

crates/rmcp/src/transport/auth.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,12 @@ impl AuthorizationManager {
945945
}
946946

947947
pub fn with_client(&mut self, http_client: ReqwestClient) -> Result<(), AuthError> {
948-
self.http_client = Arc::new(ReqwestOAuthHttpClient::new(http_client)?);
948+
// Preserve every caller-provided reqwest setting on the legacy with_client path.
949+
// Callers that need strict no-redirect handling can use OAuthHttpClient directly.
950+
self.http_client = Arc::new(ReqwestOAuthHttpClient {
951+
follow_redirects: http_client.clone(),
952+
stop_redirects: http_client,
953+
});
949954
self.refresh_redirect_policy = OAuthHttpRedirectPolicy::Follow;
950955
Ok(())
951956
}
@@ -4871,6 +4876,72 @@ mod tests {
48714876
);
48724877
}
48734878

4879+
#[tokio::test]
4880+
async fn exchange_code_uses_client_configured_by_with_client() {
4881+
use axum::{Router, body::Body, http::Response, routing::post};
4882+
4883+
let received_header = Arc::new(std::sync::Mutex::new(None));
4884+
let received_header_clone = Arc::clone(&received_header);
4885+
let app = Router::new().route(
4886+
"/token",
4887+
post(move |headers: axum::http::HeaderMap| {
4888+
let received_header = Arc::clone(&received_header_clone);
4889+
async move {
4890+
*received_header.lock().unwrap() = headers
4891+
.get("x-custom-client")
4892+
.and_then(|value| value.to_str().ok())
4893+
.map(str::to_string);
4894+
Response::builder()
4895+
.status(200)
4896+
.header("content-type", "application/json")
4897+
.body(Body::from(
4898+
r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#,
4899+
))
4900+
.unwrap()
4901+
}
4902+
}),
4903+
);
4904+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
4905+
let addr = listener.local_addr().unwrap();
4906+
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
4907+
4908+
let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
4909+
authorization_endpoint: format!("http://{addr}/authorize"),
4910+
token_endpoint: format!("http://{addr}/token"),
4911+
..Default::default()
4912+
}))
4913+
.await;
4914+
let mut default_headers = reqwest::header::HeaderMap::new();
4915+
default_headers.insert("x-custom-client", "configured".parse().unwrap());
4916+
manager
4917+
.with_client(
4918+
reqwest::Client::builder()
4919+
.default_headers(default_headers)
4920+
.build()
4921+
.unwrap(),
4922+
)
4923+
.unwrap();
4924+
manager.configure_client(test_client_config()).unwrap();
4925+
let authorization_url = manager.get_authorization_url(&[]).await.unwrap();
4926+
let state = Url::parse(&authorization_url)
4927+
.unwrap()
4928+
.query_pairs()
4929+
.find(|(name, _)| name == "state")
4930+
.unwrap()
4931+
.1
4932+
.into_owned();
4933+
4934+
manager
4935+
.exchange_code_for_token("authorization-code", &state)
4936+
.await
4937+
.unwrap();
4938+
4939+
assert_eq!(
4940+
received_header.lock().unwrap().as_deref(),
4941+
Some("configured")
4942+
);
4943+
}
4944+
48744945
async fn start_token_server() -> (String, Arc<std::sync::Mutex<Option<String>>>) {
48754946
use axum::{Router, body::Body, http::Response, routing::post};
48764947
let captured: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));

0 commit comments

Comments
 (0)