diff --git a/httpx/_auth.py b/httpx/_auth.py index 571584593b..ab2082a79a 100644 --- a/httpx/_auth.py +++ b/httpx/_auth.py @@ -17,6 +17,12 @@ class Auth: To implement a custom authentication scheme, subclass `Auth` and override the `.auth_flow()` method. + + If the authentication scheme does I/O, such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.async_auth_flow()` + to provide an async-friendly implementation that will be used by the `AsyncClient`. + Usage of sync I/O within an async codebase would block the event loop, and could + cause performance issues. """ requires_request_body = False @@ -46,6 +52,26 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non """ yield request + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O, such as disk access or network calls, + or uses concurrency primitives such as locks. + """ + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + try: + request = flow.send(response) + except StopIteration: + break + class FunctionAuth(Auth): """ diff --git a/httpx/_client.py b/httpx/_client.py index 2d2ca9ac16..939d50ae39 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -1365,15 +1365,15 @@ async def _send_handling_auth( if auth.requires_request_body: await request.aread() - auth_flow = auth.auth_flow(request) - request = next(auth_flow) + auth_flow = auth.async_auth_flow(request) + request = await auth_flow.__anext__() while True: response = await self._send_single_request(request, timeout) if auth.requires_response_body: await response.aread() try: - next_request = auth_flow.send(response) - except StopIteration: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: return response except BaseException as exc: await response.aclose() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index edfccf0a70..1140337ae7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,6 +1,8 @@ +import asyncio import hashlib import os import typing +import threading import httpcore import pytest @@ -184,6 +186,29 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non yield request +class SyncOrAsyncAuth(Auth): + """ + A mock authentication scheme that uses a different implementation for the + sync and async cases. + """ + + def __init__(self): + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + with self._lock: + request.headers["Authorization"] = "sync-auth" + yield request + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + async with self._async_lock: + request.headers["Authorization"] = "async-auth" + yield request + + @pytest.mark.asyncio async def test_basic_auth() -> None: url = "https://example.org/" @@ -641,3 +666,23 @@ def test_sync_auth_reads_response_body() -> None: response = client.get(url, auth=auth) assert response.status_code == 200 assert response.json() == {"auth": '{"auth": "xyz"}'} + + +@pytest.mark.asyncio +async def test_sync_async_auth() -> None: + """ + Test that we can use a different auth flow implementation in the async case, to + support cases that require performing I/O or using concurrency primitives (such + as checking a disk-based cache or fetching a token from a remote auth server). + """ + url = "https://example.org/" + auth = SyncOrAsyncAuth() + + client = AsyncClient(transport=AsyncMockTransport()) + response = await client.get(url, auth=auth) + assert response.status_code == 200 + assert response.json() == {"auth": "async-auth"} + + response = Client(transport=SyncMockTransport()).get(url, auth=auth) + assert response.status_code == 200 + assert response.json() == {"auth": "sync-auth"}