Skip to content

Commit 6c69e09

Browse files
No I/O auth (#644)
* No I/O on Auth
1 parent 8d4f182 commit 6c69e09

6 files changed

Lines changed: 125 additions & 68 deletions

File tree

httpx/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import typing
22

3+
from .auth import AuthTypes
34
from .client import Client, StreamContextManager
45
from .config import DEFAULT_TIMEOUT_CONFIG, CertTypes, TimeoutTypes, VerifyTypes
56
from .models import (
6-
AuthTypes,
77
CookieTypes,
88
HeaderTypes,
99
QueryParamTypes,

httpx/auth.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,75 @@
77
from urllib.request import parse_http_list
88

99
from .exceptions import ProtocolError
10-
from .middleware import Middleware
1110
from .models import Request, Response
1211
from .utils import to_bytes, to_str, unquote
1312

13+
AuthFlow = typing.Generator[Request, Response, None]
14+
15+
AuthTypes = typing.Union[
16+
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
17+
typing.Callable[["Request"], "Request"],
18+
"Auth",
19+
]
20+
21+
22+
class Auth:
23+
"""
24+
Base class for all authentication schemes.
25+
"""
26+
27+
def __call__(self, request: Request) -> AuthFlow:
28+
"""
29+
Execute the authentication flow.
30+
31+
To dispatch a request, `yield` it:
32+
33+
```
34+
yield request
35+
```
36+
37+
The client will `.send()` the response back into the flow generator. You can
38+
access it like so:
39+
40+
```
41+
response = yield request
42+
```
43+
44+
A `return` (or reaching the end of the generator) will result in the
45+
client returning the last response obtained from the server.
46+
47+
You can dispatch as many requests as is necessary.
48+
"""
49+
yield request
50+
51+
52+
class FunctionAuth(Auth):
53+
"""
54+
Allows the 'auth' argument to be passed as a simple callable function,
55+
that takes the request, and returns a new, modified request.
56+
"""
57+
58+
def __init__(self, func: typing.Callable[[Request], Request]) -> None:
59+
self.func = func
60+
61+
def __call__(self, request: Request) -> AuthFlow:
62+
yield self.func(request)
63+
64+
65+
class BasicAuth(Auth):
66+
"""
67+
Allows the 'auth' argument to be passed as a (username, password) pair,
68+
and uses HTTP Basic authentication.
69+
"""
1470

15-
class BasicAuth:
1671
def __init__(
1772
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
1873
):
1974
self.auth_header = self.build_auth_header(username, password)
2075

21-
def __call__(self, request: Request) -> Request:
76+
def __call__(self, request: Request) -> AuthFlow:
2277
request.headers["Authorization"] = self.auth_header
23-
return request
78+
yield request
2479

2580
def build_auth_header(
2681
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
@@ -30,7 +85,7 @@ def build_auth_header(
3085
return f"Basic {token}"
3186

3287

33-
class DigestAuth(Middleware):
88+
class DigestAuth(Auth):
3489
ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
3590
"MD5": hashlib.md5,
3691
"MD5-SESS": hashlib.md5,
@@ -48,22 +103,22 @@ def __init__(
48103
self.username = to_bytes(username)
49104
self.password = to_bytes(password)
50105

51-
async def __call__(
52-
self, request: Request, get_response: typing.Callable
53-
) -> Response:
54-
response = await get_response(request)
106+
def __call__(self, request: Request) -> AuthFlow:
107+
response = yield request
108+
55109
if response.status_code != 401 or "www-authenticate" not in response.headers:
56-
return response
110+
# If the response is not a 401 WWW-Authenticate, then we don't
111+
# need to build an authenticated request.
112+
return
57113

58-
await response.close()
59114
header = response.headers["www-authenticate"]
60115
try:
61116
challenge = DigestAuthChallenge.from_header(header)
62117
except ValueError:
63118
raise ProtocolError("Malformed Digest authentication header")
64119

65120
request.headers["Authorization"] = self._build_auth_header(request, challenge)
66-
return await get_response(request)
121+
yield request
67122

68123
def _build_auth_header(
69124
self, request: Request, challenge: "DigestAuthChallenge"

httpx/client.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import hstspreload
77

8-
from .auth import BasicAuth
8+
from .auth import Auth, AuthTypes, BasicAuth, FunctionAuth
99
from .concurrency.base import ConcurrencyBackend
1010
from .config import (
1111
DEFAULT_MAX_REDIRECTS,
@@ -31,10 +31,8 @@
3131
RedirectLoop,
3232
TooManyRedirects,
3333
)
34-
from .middleware import Middleware
3534
from .models import (
3635
URL,
37-
AuthTypes,
3836
Cookies,
3937
CookieTypes,
4038
Headers,
@@ -397,28 +395,18 @@ async def send(
397395
if request.url.scheme not in ("http", "https"):
398396
raise InvalidURL('URL scheme must be "http" or "https".')
399397

400-
auth = self.auth if auth is None else auth
401-
trust_env = self.trust_env if trust_env is None else trust_env
402398
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
403399

404-
if not isinstance(auth, Middleware):
405-
request = self.authenticate(request, trust_env, auth)
406-
response = await self.send_handling_redirects(
407-
request,
408-
verify=verify,
409-
cert=cert,
410-
timeout=timeout,
411-
allow_redirects=allow_redirects,
412-
)
413-
else:
414-
get_response = functools.partial(
415-
self.send_handling_redirects,
416-
verify=verify,
417-
cert=cert,
418-
timeout=timeout,
419-
allow_redirects=allow_redirects,
420-
)
421-
response = await auth(request, get_response)
400+
auth = self.setup_auth(request, trust_env, auth)
401+
402+
response = await self.send_handling_redirects(
403+
request,
404+
auth=auth,
405+
verify=verify,
406+
cert=cert,
407+
timeout=timeout,
408+
allow_redirects=allow_redirects,
409+
)
422410

423411
if not stream:
424412
try:
@@ -428,30 +416,36 @@ async def send(
428416

429417
return response
430418

431-
def authenticate(
432-
self, request: Request, trust_env: bool, auth: AuthTypes = None
433-
) -> "Request":
419+
def setup_auth(
420+
self, request: Request, trust_env: bool = None, auth: AuthTypes = None
421+
) -> Auth:
422+
auth = self.auth if auth is None else auth
423+
trust_env = self.trust_env if trust_env is None else trust_env
424+
434425
if auth is not None:
435426
if isinstance(auth, tuple):
436-
auth = BasicAuth(username=auth[0], password=auth[1])
437-
return auth(request)
427+
return BasicAuth(username=auth[0], password=auth[1])
428+
elif isinstance(auth, Auth):
429+
return auth
430+
elif callable(auth):
431+
return FunctionAuth(func=auth)
432+
raise TypeError('Invalid "auth" argument.')
438433

439434
username, password = request.url.username, request.url.password
440435
if username or password:
441-
auth = BasicAuth(username=username, password=password)
442-
return auth(request)
436+
return BasicAuth(username=username, password=password)
443437

444438
if trust_env and "Authorization" not in request.headers:
445439
credentials = self.netrc.get_credentials(request.url.authority)
446440
if credentials is not None:
447-
auth = BasicAuth(username=credentials[0], password=credentials[1])
448-
return auth(request)
441+
return BasicAuth(username=credentials[0], password=credentials[1])
449442

450-
return request
443+
return Auth()
451444

452445
async def send_handling_redirects(
453446
self,
454447
request: Request,
448+
auth: Auth,
455449
timeout: Timeout,
456450
verify: VerifyTypes = None,
457451
cert: CertTypes = None,
@@ -467,8 +461,8 @@ async def send_handling_redirects(
467461
if request.url in (response.url for response in history):
468462
raise RedirectLoop()
469463

470-
response = await self.send_single_request(
471-
request, verify=verify, cert=cert, timeout=timeout
464+
response = await self.send_handling_auth(
465+
request, auth=auth, timeout=timeout, verify=verify, cert=cert
472466
)
473467
response.history = list(history)
474468

@@ -483,6 +477,7 @@ async def send_handling_redirects(
483477
response.call_next = functools.partial(
484478
self.send_handling_redirects,
485479
request=request,
480+
auth=auth,
486481
verify=verify,
487482
cert=cert,
488483
timeout=timeout,
@@ -581,6 +576,29 @@ def redirect_content(self, request: Request, method: str) -> RequestContent:
581576
raise RedirectBodyUnavailable()
582577
return request.content
583578

579+
async def send_handling_auth(
580+
self,
581+
request: Request,
582+
auth: Auth,
583+
timeout: Timeout,
584+
verify: VerifyTypes = None,
585+
cert: CertTypes = None,
586+
) -> Response:
587+
auth_flow = auth(request)
588+
request = next(auth_flow)
589+
while True:
590+
response = await self.send_single_request(request, timeout, verify, cert)
591+
try:
592+
next_request = auth_flow.send(response)
593+
except StopIteration:
594+
return response
595+
except BaseException as exc:
596+
await response.close()
597+
raise exc from None
598+
else:
599+
request = next_request
600+
await response.close()
601+
584602
async def send_single_request(
585603
self,
586604
request: Request,

httpx/middleware.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

httpx/models.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,6 @@
6767

6868
CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
6969

70-
AuthTypes = typing.Union[
71-
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
72-
typing.Callable[["Request"], "Request"],
73-
"BaseMiddleware",
74-
]
75-
7670
ProxiesTypes = typing.Union[
7771
URLTypes, "Dispatcher", typing.Dict[URLTypes, typing.Union[URLTypes, "Dispatcher"]]
7872
]

tests/client/test_auth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,14 @@ async def test_trust_env_auth():
164164
os.environ["NETRC"] = "tests/.netrc"
165165
url = "http://netrcexample.org"
166166

167-
client = Client(dispatch=MockDispatch())
168-
response = await client.get(url, trust_env=False)
167+
client = Client(dispatch=MockDispatch(), trust_env=False)
168+
response = await client.get(url)
169169

170170
assert response.status_code == 200
171171
assert response.json() == {"auth": None}
172172

173-
client = Client(dispatch=MockDispatch(), trust_env=False)
174-
response = await client.get(url, trust_env=True)
173+
client = Client(dispatch=MockDispatch(), trust_env=True)
174+
response = await client.get(url)
175175

176176
assert response.status_code == 200
177177
assert response.json() == {

0 commit comments

Comments
 (0)