@@ -8,7 +8,7 @@ use crate::{
88 realm:: Realm ,
99 CredentialsCache , KeyringProvider , CREDENTIALS_CACHE ,
1010} ;
11- use anyhow:: anyhow;
11+ use anyhow:: { anyhow, format_err } ;
1212use netrc:: Netrc ;
1313use reqwest:: { Request , Response } ;
1414use reqwest_middleware:: { Error , Middleware , Next } ;
@@ -22,6 +22,11 @@ pub struct AuthMiddleware {
2222 netrc : Option < Netrc > ,
2323 keyring : Option < KeyringProvider > ,
2424 cache : Option < CredentialsCache > ,
25+ /// We know that the endpoint needs authentication, so we don't try to send an unauthenticated
26+ /// request.
27+ ///
28+ /// This is also useful since it avoids cloning an unclonable request.
29+ only_authenticated : bool ,
2530}
2631
2732impl AuthMiddleware {
@@ -30,6 +35,7 @@ impl AuthMiddleware {
3035 netrc : Netrc :: new ( ) . ok ( ) ,
3136 keyring : None ,
3237 cache : None ,
38+ only_authenticated : false ,
3339 }
3440 }
3541
@@ -56,6 +62,16 @@ impl AuthMiddleware {
5662 self
5763 }
5864
65+ /// We know that the endpoint needs authentication, so we don't try to send an unauthenticated
66+ /// request.
67+ ///
68+ /// This is also useful since it avoids cloning an unclonable request.
69+ #[ must_use]
70+ pub fn with_only_authenticated ( mut self , only_authenticated : bool ) -> Self {
71+ self . only_authenticated = only_authenticated;
72+ self
73+ }
74+
5975 /// Get the configured authentication store.
6076 ///
6177 /// If not set, the global store is used.
@@ -198,32 +214,42 @@ impl Middleware for AuthMiddleware {
198214 . as_ref ( )
199215 . is_some_and ( |credentials| credentials. username ( ) . is_some ( ) ) ;
200216
201- // Otherwise, attempt an anonymous request
202- trace ! ( "Attempting unauthenticated request for {url}" ) ;
203-
204- // <https://github.com/TrueLayer/reqwest-middleware/blob/abdf1844c37092d323683c2396b7eefda1418d3c/reqwest-retry/src/middleware.rs#L141-L149>
205- // Clone the request so we can retry it on authentication failure
206- let mut retry_request = request. try_clone ( ) . ok_or_else ( || {
207- Error :: Middleware ( anyhow ! (
208- "Request object is not cloneable. Are you passing a streaming body?" . to_string( )
209- ) )
210- } ) ?;
211-
212- let response = next. clone ( ) . run ( request, extensions) . await ?;
213-
214- // If we don't fail with authorization related codes, return the response
215- if !matches ! (
216- response. status( ) ,
217- StatusCode :: FORBIDDEN | StatusCode :: NOT_FOUND | StatusCode :: UNAUTHORIZED
218- ) {
219- return Ok ( response) ;
220- }
217+ let ( mut retry_request, response) = if self . only_authenticated {
218+ // For endpoints where we require the user to provide credentials, we don't try the
219+ // unauthenticated request first.
220+ trace ! ( "Checking for credentials for {url}" ) ;
221+ ( request, None )
222+ } else {
223+ // Otherwise, attempt an anonymous request
224+ trace ! ( "Attempting unauthenticated request for {url}" ) ;
225+
226+ // <https://github.com/TrueLayer/reqwest-middleware/blob/abdf1844c37092d323683c2396b7eefda1418d3c/reqwest-retry/src/middleware.rs#L141-L149>
227+ // Clone the request so we can retry it on authentication failure
228+ let retry_request = request. try_clone ( ) . ok_or_else ( || {
229+ Error :: Middleware ( anyhow ! (
230+ "Request object is not cloneable. Are you passing a streaming body?"
231+ . to_string( )
232+ ) )
233+ } ) ?;
234+
235+ let response = next. clone ( ) . run ( request, extensions) . await ?;
236+
237+ // If we don't fail with authorization related codes, return the response
238+ if !matches ! (
239+ response. status( ) ,
240+ StatusCode :: FORBIDDEN | StatusCode :: NOT_FOUND | StatusCode :: UNAUTHORIZED
241+ ) {
242+ return Ok ( response) ;
243+ }
221244
222- // Otherwise, search for credentials
223- trace ! (
224- "Request for {url} failed with {}, checking for credentials" ,
225- response. status( )
226- ) ;
245+ // Otherwise, search for credentials
246+ trace ! (
247+ "Request for {url} failed with {}, checking for credentials" ,
248+ response. status( )
249+ ) ;
250+
251+ ( retry_request, Some ( response) )
252+ } ;
227253
228254 // Check in the cache first
229255 let credentials = self . cache ( ) . get_realm (
@@ -265,7 +291,13 @@ impl Middleware for AuthMiddleware {
265291 }
266292 }
267293
268- Ok ( response)
294+ if let Some ( response) = response {
295+ Ok ( response)
296+ } else {
297+ Err ( Error :: Middleware ( format_err ! (
298+ "Missing credentials for {url}"
299+ ) ) )
300+ }
269301 }
270302}
271303
0 commit comments