Skip to content

Commit

Permalink
Add support for async auth flows
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Aug 25, 2020
1 parent 4161d7a commit eaf43b3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
26 changes: 26 additions & 0 deletions httpx/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ class Auth:
To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.
If the authentication scheme does I/O, such as disk access or network calls, or uses
synchronization primitives such as locks, you should override `.async_auth_flow()`
to provide an async-friendly implementation that will be used by the `AsyncClient`.
Usage of sync I/O within an async codebase would block the event loop, and could
cause performance issues.
"""

requires_request_body = False
Expand Down Expand Up @@ -46,6 +52,26 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
"""
yield request

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.
By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O, such as disk access or network calls,
or uses concurrency primitives such as locks.
"""
flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
try:
request = flow.send(response)
except StopIteration:
break


class FunctionAuth(Auth):
"""
Expand Down
8 changes: 4 additions & 4 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,15 +1365,15 @@ async def _send_handling_auth(
if auth.requires_request_body:
await request.aread()

auth_flow = auth.auth_flow(request)
request = next(auth_flow)
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.__anext__()
while True:
response = await self._send_single_request(request, timeout)
if auth.requires_response_body:
await response.aread()
try:
next_request = auth_flow.send(response)
except StopIteration:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:
return response
except BaseException as exc:
await response.aclose()
Expand Down
45 changes: 45 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import hashlib
import os
import typing
import threading

import httpcore
import pytest
Expand Down Expand Up @@ -184,6 +186,29 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
yield request


class SyncOrAsyncAuth(Auth):
"""
A mock authentication scheme that uses a different implementation for the
sync and async cases.
"""

def __init__(self):
self._lock = threading.Lock()
self._async_lock = asyncio.Lock()

def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
with self._lock:
request.headers["Authorization"] = "sync-auth"
yield request

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
async with self._async_lock:
request.headers["Authorization"] = "async-auth"
yield request


@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
Expand Down Expand Up @@ -641,3 +666,23 @@ def test_sync_auth_reads_response_body() -> None:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}


@pytest.mark.asyncio
async def test_sync_async_auth() -> None:
"""
Test that we can use a different auth flow implementation in the async case, to
support cases that require performing I/O or using concurrency primitives (such
as checking a disk-based cache or fetching a token from a remote auth server).
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()

client = AsyncClient(transport=AsyncMockTransport())
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "async-auth"}

response = Client(transport=SyncMockTransport()).get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "sync-auth"}

0 comments on commit eaf43b3

Please sign in to comment.