@@ -944,8 +944,20 @@ impl AuthorizationManager {
944944 Ok ( false )
945945 }
946946
947+ /// Use a caller-configured `reqwest::Client` for every OAuth HTTP operation,
948+ /// preserving all of its settings (proxy, TLS, timeout, default headers).
949+ ///
950+ /// The same client is reused for all requests, so its own redirect policy applies
951+ /// and [`OAuthHttpRedirectPolicy::Stop`] is not enforced for token operations.
952+ /// Callers needing strict no-redirect handling should pass a custom
953+ /// [`OAuthHttpClient`] to [`AuthorizationManager::new_with_oauth_http_client`].
947954 pub fn with_client ( & mut self , http_client : ReqwestClient ) -> Result < ( ) , AuthError > {
948- self . http_client = Arc :: new ( ReqwestOAuthHttpClient :: new ( http_client) ?) ;
955+ // One client for both modes: a built reqwest::Client can't be rebuilt as a
956+ // no-redirect variant without dropping the caller's configuration.
957+ self . http_client = Arc :: new ( ReqwestOAuthHttpClient {
958+ follow_redirects : http_client. clone ( ) ,
959+ stop_redirects : http_client,
960+ } ) ;
949961 self . refresh_redirect_policy = OAuthHttpRedirectPolicy :: Follow ;
950962 Ok ( ( ) )
951963 }
@@ -4871,6 +4883,149 @@ mod tests {
48714883 ) ;
48724884 }
48734885
4886+ #[ tokio:: test]
4887+ async fn exchange_code_uses_client_configured_by_with_client ( ) {
4888+ use axum:: { Router , body:: Body , http:: Response , routing:: post} ;
4889+
4890+ let received_header = Arc :: new ( std:: sync:: Mutex :: new ( None ) ) ;
4891+ let received_header_clone = Arc :: clone ( & received_header) ;
4892+ let app = Router :: new ( ) . route (
4893+ "/token" ,
4894+ post ( move |headers : axum:: http:: HeaderMap | {
4895+ let received_header = Arc :: clone ( & received_header_clone) ;
4896+ async move {
4897+ * received_header. lock ( ) . unwrap ( ) = headers
4898+ . get ( "x-custom-client" )
4899+ . and_then ( |value| value. to_str ( ) . ok ( ) )
4900+ . map ( str:: to_string) ;
4901+ Response :: builder ( )
4902+ . status ( 200 )
4903+ . header ( "content-type" , "application/json" )
4904+ . body ( Body :: from (
4905+ r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"# ,
4906+ ) )
4907+ . unwrap ( )
4908+ }
4909+ } ) ,
4910+ ) ;
4911+ let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
4912+ let addr = listener. local_addr ( ) . unwrap ( ) ;
4913+ tokio:: spawn ( async move { axum:: serve ( listener, app) . await . unwrap ( ) } ) ;
4914+
4915+ let mut manager = manager_with_metadata ( Some ( AuthorizationMetadata {
4916+ authorization_endpoint : format ! ( "http://{addr}/authorize" ) ,
4917+ token_endpoint : format ! ( "http://{addr}/token" ) ,
4918+ ..Default :: default ( )
4919+ } ) )
4920+ . await ;
4921+ let mut default_headers = reqwest:: header:: HeaderMap :: new ( ) ;
4922+ default_headers. insert ( "x-custom-client" , "configured" . parse ( ) . unwrap ( ) ) ;
4923+ manager
4924+ . with_client (
4925+ reqwest:: Client :: builder ( )
4926+ . default_headers ( default_headers)
4927+ . build ( )
4928+ . unwrap ( ) ,
4929+ )
4930+ . unwrap ( ) ;
4931+ manager. configure_client ( test_client_config ( ) ) . unwrap ( ) ;
4932+ let authorization_url = manager. get_authorization_url ( & [ ] ) . await . unwrap ( ) ;
4933+ let state = Url :: parse ( & authorization_url)
4934+ . unwrap ( )
4935+ . query_pairs ( )
4936+ . find ( |( name, _) | name == "state" )
4937+ . unwrap ( )
4938+ . 1
4939+ . into_owned ( ) ;
4940+
4941+ manager
4942+ . exchange_code_for_token ( "authorization-code" , & state)
4943+ . await
4944+ . unwrap ( ) ;
4945+
4946+ assert_eq ! (
4947+ received_header. lock( ) . unwrap( ) . as_deref( ) ,
4948+ Some ( "configured" )
4949+ ) ;
4950+ }
4951+
4952+ #[ tokio:: test]
4953+ async fn exchange_code_follows_redirects_with_with_client ( ) {
4954+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
4955+
4956+ use axum:: {
4957+ Router ,
4958+ body:: Body ,
4959+ http:: { Response , StatusCode } ,
4960+ routing:: post,
4961+ } ;
4962+
4963+ // The token endpoint replies with a 307 redirect; the with_client path reuses
4964+ // the caller's redirect-following client, so the request is expected to follow
4965+ // it to the final endpoint that returns the token.
4966+ let final_endpoint_hit = Arc :: new ( AtomicBool :: new ( false ) ) ;
4967+ let final_endpoint_hit_clone = Arc :: clone ( & final_endpoint_hit) ;
4968+ let app = Router :: new ( )
4969+ . route (
4970+ "/token" ,
4971+ post ( || async {
4972+ Response :: builder ( )
4973+ . status ( StatusCode :: TEMPORARY_REDIRECT )
4974+ . header ( "location" , "/token-final" )
4975+ . body ( Body :: empty ( ) )
4976+ . unwrap ( )
4977+ } ) ,
4978+ )
4979+ . route (
4980+ "/token-final" ,
4981+ post ( move || {
4982+ let final_endpoint_hit = Arc :: clone ( & final_endpoint_hit_clone) ;
4983+ async move {
4984+ final_endpoint_hit. store ( true , Ordering :: SeqCst ) ;
4985+ Response :: builder ( )
4986+ . status ( 200 )
4987+ . header ( "content-type" , "application/json" )
4988+ . body ( Body :: from (
4989+ r#"{"access_token":"redirected-token","token_type":"Bearer","expires_in":3600}"# ,
4990+ ) )
4991+ . unwrap ( )
4992+ }
4993+ } ) ,
4994+ ) ;
4995+ let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
4996+ let addr = listener. local_addr ( ) . unwrap ( ) ;
4997+ tokio:: spawn ( async move { axum:: serve ( listener, app) . await . unwrap ( ) } ) ;
4998+
4999+ let mut manager = manager_with_metadata ( Some ( AuthorizationMetadata {
5000+ authorization_endpoint : format ! ( "http://{addr}/authorize" ) ,
5001+ token_endpoint : format ! ( "http://{addr}/token" ) ,
5002+ ..Default :: default ( )
5003+ } ) )
5004+ . await ;
5005+ manager
5006+ . with_client ( reqwest:: Client :: builder ( ) . build ( ) . unwrap ( ) )
5007+ . unwrap ( ) ;
5008+ manager. configure_client ( test_client_config ( ) ) . unwrap ( ) ;
5009+ let authorization_url = manager. get_authorization_url ( & [ ] ) . await . unwrap ( ) ;
5010+ let state = Url :: parse ( & authorization_url)
5011+ . unwrap ( )
5012+ . query_pairs ( )
5013+ . find ( |( name, _) | name == "state" )
5014+ . unwrap ( )
5015+ . 1
5016+ . into_owned ( ) ;
5017+
5018+ manager
5019+ . exchange_code_for_token ( "authorization-code" , & state)
5020+ . await
5021+ . unwrap ( ) ;
5022+
5023+ assert ! (
5024+ final_endpoint_hit. load( Ordering :: SeqCst ) ,
5025+ "with_client path should follow redirects on token exchange"
5026+ ) ;
5027+ }
5028+
48745029 async fn start_token_server ( ) -> ( String , Arc < std:: sync:: Mutex < Option < String > > > ) {
48755030 use axum:: { Router , body:: Body , http:: Response , routing:: post} ;
48765031 let captured: Arc < std:: sync:: Mutex < Option < String > > > = Arc :: new ( std:: sync:: Mutex :: new ( None ) ) ;
0 commit comments