55
66import hstspreload
77
8- from .auth import BasicAuth
8+ from .auth import Auth , AuthTypes , BasicAuth , FunctionAuth
99from .concurrency .base import ConcurrencyBackend
1010from .config import (
1111 DEFAULT_MAX_REDIRECTS ,
3131 RedirectLoop ,
3232 TooManyRedirects ,
3333)
34- from .middleware import Middleware
3534from .models import (
3635 URL ,
37- AuthTypes ,
3836 Cookies ,
3937 CookieTypes ,
4038 Headers ,
@@ -397,28 +395,18 @@ async def send(
397395 if request .url .scheme not in ("http" , "https" ):
398396 raise InvalidURL ('URL scheme must be "http" or "https".' )
399397
400- auth = self .auth if auth is None else auth
401- trust_env = self .trust_env if trust_env is None else trust_env
402398 timeout = self .timeout if isinstance (timeout , UnsetType ) else Timeout (timeout )
403399
404- if not isinstance (auth , Middleware ):
405- request = self .authenticate (request , trust_env , auth )
406- response = await self .send_handling_redirects (
407- request ,
408- verify = verify ,
409- cert = cert ,
410- timeout = timeout ,
411- allow_redirects = allow_redirects ,
412- )
413- else :
414- get_response = functools .partial (
415- self .send_handling_redirects ,
416- verify = verify ,
417- cert = cert ,
418- timeout = timeout ,
419- allow_redirects = allow_redirects ,
420- )
421- response = await auth (request , get_response )
400+ auth = self .setup_auth (request , trust_env , auth )
401+
402+ response = await self .send_handling_redirects (
403+ request ,
404+ auth = auth ,
405+ verify = verify ,
406+ cert = cert ,
407+ timeout = timeout ,
408+ allow_redirects = allow_redirects ,
409+ )
422410
423411 if not stream :
424412 try :
@@ -428,30 +416,36 @@ async def send(
428416
429417 return response
430418
431- def authenticate (
432- self , request : Request , trust_env : bool , auth : AuthTypes = None
433- ) -> "Request" :
419+ def setup_auth (
420+ self , request : Request , trust_env : bool = None , auth : AuthTypes = None
421+ ) -> Auth :
422+ auth = self .auth if auth is None else auth
423+ trust_env = self .trust_env if trust_env is None else trust_env
424+
434425 if auth is not None :
435426 if isinstance (auth , tuple ):
436- auth = BasicAuth (username = auth [0 ], password = auth [1 ])
437- return auth (request )
427+ return BasicAuth (username = auth [0 ], password = auth [1 ])
428+ elif isinstance (auth , Auth ):
429+ return auth
430+ elif callable (auth ):
431+ return FunctionAuth (func = auth )
432+ raise TypeError ('Invalid "auth" argument.' )
438433
439434 username , password = request .url .username , request .url .password
440435 if username or password :
441- auth = BasicAuth (username = username , password = password )
442- return auth (request )
436+ return BasicAuth (username = username , password = password )
443437
444438 if trust_env and "Authorization" not in request .headers :
445439 credentials = self .netrc .get_credentials (request .url .authority )
446440 if credentials is not None :
447- auth = BasicAuth (username = credentials [0 ], password = credentials [1 ])
448- return auth (request )
441+ return BasicAuth (username = credentials [0 ], password = credentials [1 ])
449442
450- return request
443+ return Auth ()
451444
452445 async def send_handling_redirects (
453446 self ,
454447 request : Request ,
448+ auth : Auth ,
455449 timeout : Timeout ,
456450 verify : VerifyTypes = None ,
457451 cert : CertTypes = None ,
@@ -467,8 +461,8 @@ async def send_handling_redirects(
467461 if request .url in (response .url for response in history ):
468462 raise RedirectLoop ()
469463
470- response = await self .send_single_request (
471- request , verify = verify , cert = cert , timeout = timeout
464+ response = await self .send_handling_auth (
465+ request , auth = auth , timeout = timeout , verify = verify , cert = cert
472466 )
473467 response .history = list (history )
474468
@@ -483,6 +477,7 @@ async def send_handling_redirects(
483477 response .call_next = functools .partial (
484478 self .send_handling_redirects ,
485479 request = request ,
480+ auth = auth ,
486481 verify = verify ,
487482 cert = cert ,
488483 timeout = timeout ,
@@ -581,6 +576,29 @@ def redirect_content(self, request: Request, method: str) -> RequestContent:
581576 raise RedirectBodyUnavailable ()
582577 return request .content
583578
579+ async def send_handling_auth (
580+ self ,
581+ request : Request ,
582+ auth : Auth ,
583+ timeout : Timeout ,
584+ verify : VerifyTypes = None ,
585+ cert : CertTypes = None ,
586+ ) -> Response :
587+ auth_flow = auth (request )
588+ request = next (auth_flow )
589+ while True :
590+ response = await self .send_single_request (request , timeout , verify , cert )
591+ try :
592+ next_request = auth_flow .send (response )
593+ except StopIteration :
594+ return response
595+ except BaseException as exc :
596+ await response .close ()
597+ raise exc from None
598+ else :
599+ request = next_request
600+ await response .close ()
601+
584602 async def send_single_request (
585603 self ,
586604 request : Request ,
0 commit comments