Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sync-specific or async-specific auth flows #1217

Merged
merged 16 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth):
...
```

If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`.

```python
import asyncio
import threading
import httpx


class MyCustomAuth(httpx.Auth):
def __init__(self):
self._sync_lock = threading.RLock()
self._async_lock = asyncio.Lock()

def sync_get_token(self):
with self._sync_lock:
...

def sync_auth_flow(self, request):
token = self.sync_get_token()
request.headers["Authorization"] = f"Token {token}"
yield request

async def async_get_token(self):
async with self._async_lock:
...

async def async_auth_flow(self, request):
token = await self.async_get_token()
request.headers["Authorization"] = f"Token {token}"
yield request
```

If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`.

```python
import httpx
import sync_only_library


class MyCustomAuth(httpx.Auth):
def sync_auth_flow(self, request):
token = sync_only_library.get_token(...)
request.headers["Authorization"] = f"Token {token}"
yield request

async def async_auth_flow(self, request):
raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient")
```

## SSL certificates

When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).
Expand Down
55 changes: 55 additions & 0 deletions httpx/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class Auth:

To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.

If the authentication scheme does I/O such as disk access or network calls, or uses
synchronization primitives such as locks, you should override `.sync_auth_flow()`
and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
implementations that will be used by `Client` and `AsyncClient` respectively.
"""

requires_request_body = False
Expand Down Expand Up @@ -46,6 +51,56 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
"""
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
yield request

def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
"""
Execute the authentication flow synchronously.

By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
request.read()

flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
if self.requires_response_body:
response.read()

try:
request = flow.send(response)
except StopIteration:
break

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.

By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
await request.aread()

flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
if self.requires_response_body:
await response.aread()

try:
request = flow.send(response)
except StopIteration:
break


class FunctionAuth(Auth):
"""
Expand Down
22 changes: 8 additions & 14 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,15 +785,12 @@ def _send_handling_auth(
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
request.read()

auth_flow = auth.auth_flow(request)
auth_flow = auth.sync_auth_flow(request)
request = next(auth_flow)

while True:
response = self._send_single_request(request, timeout)
if auth.requires_response_body:
response.read()

try:
next_request = auth_flow.send(response)
except StopIteration:
Expand Down Expand Up @@ -1409,18 +1406,15 @@ async def _send_handling_auth(
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
await request.aread()
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.__anext__()

auth_flow = auth.auth_flow(request)
request = next(auth_flow)
while True:
response = await self._send_single_request(request, timeout)
if auth.requires_response_body:
await response.aread()

try:
next_request = auth_flow.send(response)
except StopIteration:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:
return response
except BaseException as exc:
await response.aclose()
Expand Down
63 changes: 63 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
"""
Integration tests for authentication.

Unit tests for auth classes also exist in tests/test_auth.py
"""
import asyncio
import hashlib
import os
import threading
import typing

import httpcore
Expand Down Expand Up @@ -183,6 +190,31 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
yield request


class SyncOrAsyncAuth(Auth):
"""
A mock authentication scheme that uses a different implementation for the
sync and async cases.
"""

def __init__(self) -> None:
self._lock = threading.Lock()
self._async_lock = asyncio.Lock()

def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
with self._lock:
request.headers["Authorization"] = "sync-auth"
yield request

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
async with self._async_lock:
request.headers["Authorization"] = "async-auth"
yield request


@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
Expand Down Expand Up @@ -664,3 +696,34 @@ def test_sync_auth_reads_response_body() -> None:

assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}


@pytest.mark.asyncio
async def test_async_auth() -> None:
"""
Test that we can use an auth implementation specific to the async case, to
support cases that require performing I/O or using concurrency primitives (such
as checking a disk-based cache or fetching a token from a remote auth server).
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()

async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": "async-auth"}


def test_sync_auth() -> None:
"""
Test that we can use an auth implementation specific to the sync case.
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()

with httpx.Client(transport=SyncMockTransport()) as client:
response = client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": "sync-auth"}
63 changes: 63 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Unit tests for auth classes.

Integration tests also exist in tests/client/test_auth.py
"""
import pytest

import httpx


def test_basic_auth():
auth = httpx.BasicAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")

# The initial request should include a basic auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert request.headers["Authorization"].startswith("Basic")

# No other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)


def test_digest_auth_with_200():
auth = httpx.DigestAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")

# The initial request should not include an auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert "Authorization" not in request.headers

# If a 200 response is returned, then no other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)


def test_digest_auth_with_401():
auth = httpx.DigestAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")

# The initial request should not include an auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert "Authorization" not in request.headers

# If a 401 response is returned, then a digest auth request is made.
headers = {
"WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
}
response = httpx.Response(
content=b"Auth required", status_code=401, headers=headers
)
request = flow.send(response)
assert request.headers["Authorization"].startswith("Digest")

# No other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)