Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sync-specific or async-specific auth flows #1217

Merged
merged 16 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions httpx/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class Auth:

To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.

If the authentication scheme does I/O such as disk access or network calls, or uses
synchronization primitives such as locks, you should override `.sync_auth_flow()`
and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
implementations that will be used by `Client` and `AsyncClient` respectively.
"""

requires_request_body = False
Expand Down Expand Up @@ -46,6 +51,56 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
"""
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
yield request

def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
"""
Execute the authentication flow synchronously.

By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
request.read()

flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
if self.requires_response_body:
response.read()

try:
request = flow.send(response)
except StopIteration:
break

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.

By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
await request.aread()

flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
if self.requires_response_body:
await response.aread()

try:
request = flow.send(response)
except StopIteration:
break


class FunctionAuth(Auth):
"""
Expand Down
22 changes: 8 additions & 14 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,15 +760,12 @@ def _send_handling_auth(
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
request.read()
auth_flow = auth.sync_auth_flow(request)
request = auth_flow.send(None) # type: ignore
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved

auth_flow = auth.auth_flow(request)
request = next(auth_flow)
while True:
response = self._send_single_request(request, timeout)
if auth.requires_response_body:
response.read()

try:
next_request = auth_flow.send(response)
except StopIteration:
Expand Down Expand Up @@ -1369,18 +1366,15 @@ async def _send_handling_auth(
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
await request.aread()
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.asend(None) # type: ignore

auth_flow = auth.auth_flow(request)
request = next(auth_flow)
while True:
response = await self._send_single_request(request, timeout)
if auth.requires_response_body:
await response.aread()

try:
next_request = auth_flow.send(response)
except StopIteration:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:
return response
except BaseException as exc:
await response.aclose()
Expand Down
45 changes: 45 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import hashlib
import os
import threading
import typing

import httpcore
Expand Down Expand Up @@ -184,6 +186,29 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
yield request


class SyncOrAsyncAuth(Auth):
"""
A mock authentication scheme that uses a different implementation for the
sync and async cases.
"""

def __init__(self):
self._lock = threading.Lock()
self._async_lock = asyncio.Lock()

def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
with self._lock:
request.headers["Authorization"] = "sync-auth"
yield request

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
async with self._async_lock:
request.headers["Authorization"] = "async-auth"
yield request


@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
Expand Down Expand Up @@ -641,3 +666,23 @@ def test_sync_auth_reads_response_body() -> None:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}


@pytest.mark.asyncio
async def test_sync_async_auth() -> None:
"""
Test that we can use a different auth flow implementation in the async case, to
support cases that require performing I/O or using concurrency primitives (such
as checking a disk-based cache or fetching a token from a remote auth server).
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()

client = AsyncClient(transport=AsyncMockTransport())
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "async-auth"}

response = Client(transport=SyncMockTransport()).get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "sync-auth"}