Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Response(content=<bytes iterator>) #1265

Merged
merged 4 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions httpx/_content_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import httpcore

from ._exceptions import StreamConsumed
from ._types import FileContent, FileTypes, RequestData, RequestFiles
from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent
from ._utils import (
format_form_param,
guess_content_type,
Expand Down Expand Up @@ -72,11 +72,8 @@ class IteratorStream(ContentStream):
Request content encoded as plain bytes, using an byte iterator.
"""

def __init__(
self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
self.iterator = iterator
self.close_func = close_func
self.is_stream_consumed = False

def can_replay(self) -> bool:
Expand All @@ -95,21 +92,14 @@ def __iter__(self) -> typing.Iterator[bytes]:
def __aiter__(self) -> typing.AsyncIterator[bytes]:
raise RuntimeError("Attempted to call a async iterator on an sync stream.")

def close(self) -> None:
if self.close_func is not None:
self.close_func()


class AsyncIteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an async byte iterator.
"""

def __init__(
self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
self.aiterator = aiterator
self.close_func = close_func
self.is_stream_consumed = False

def can_replay(self) -> bool:
Expand All @@ -128,10 +118,6 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for part in self.aiterator:
yield part

async def aclose(self) -> None:
if self.close_func is not None:
await self.close_func()


class JSONStream(ContentStream):
"""
Expand Down Expand Up @@ -402,3 +388,18 @@ def encode(
return IteratorStream(iterator=data)

raise TypeError(f"Unexpected type for 'data', {type(data)!r}")


def encode_response(content: ResponseContent = None) -> ContentStream:
if content is None:
return ByteStream(b"")
elif isinstance(content, bytes):
return ByteStream(body=content)
elif hasattr(content, "__aiter__"):
content = typing.cast(typing.AsyncIterator[bytes], content)
return AsyncIteratorStream(aiterator=content)
Comment on lines +398 to +400
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you were already aware, but a thing I just learned from reviewing encode/starlette#1041: collections.abc.AsyncIterator/collections.abc.Iterator (which the typing equivalents derive from) have a subclasshook which I think means you could just do:

    elif isinstance(content, typing.AsyncIterator):
          return AsyncIteratorStream(aiterator=content)

Also technically you may be casting Iterables to Iterators here by just checking for the presence of __aiter__, though it seems unlikely to cause problems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elif hasattr(content, "__iter__"):
content = typing.cast(typing.Iterator[bytes], content)
return IteratorStream(iterator=content)

raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
11 changes: 7 additions & 4 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import rfc3986
import rfc3986.exceptions

from ._content_streams import ByteStream, ContentStream, encode
from ._content_streams import ByteStream, ContentStream, encode, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ContentDecoder,
Expand Down Expand Up @@ -44,6 +44,7 @@
QueryParamTypes,
RequestData,
RequestFiles,
ResponseContent,
URLTypes,
)
from ._utils import (
Expand Down Expand Up @@ -674,7 +675,7 @@ def __init__(
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
content: bytes = None,
content: ResponseContent = None,
history: typing.List["Response"] = None,
elapsed_func: typing.Callable = None,
):
Expand All @@ -694,8 +695,10 @@ def __init__(
if stream is not None:
self._raw_stream = stream
else:
self._raw_stream = ByteStream(body=content or b"")
self.read()
self._raw_stream = encode_response(content)
if content is None or isinstance(content, bytes):
# Load the response body, except for streaming content.
self.read()

self._num_bytes_downloaded = 0

Expand Down
2 changes: 2 additions & 0 deletions httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
None,
]

ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]

RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]]

FileContent = Union[IO[str], IO[bytes], str, bytes]
Expand Down
63 changes: 11 additions & 52 deletions tests/models/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest

import httpx
from httpx._content_streams import AsyncIteratorStream, IteratorStream


def streaming_body():
Expand Down Expand Up @@ -215,10 +214,9 @@ async def test_aread():


def test_iter_raw():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

raw = b""
Expand All @@ -228,12 +226,7 @@ def test_iter_raw():


def test_iter_raw_increments_updates_counter():
stream = IteratorStream(iterator=streaming_body())

response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=streaming_body())

num_downloaded = response.num_bytes_downloaded
for part in response.iter_raw():
Expand All @@ -243,11 +236,7 @@ def test_iter_raw_increments_updates_counter():

@pytest.mark.asyncio
async def test_aiter_raw():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())

raw = b""
async for part in response.aiter_raw():
Expand All @@ -257,12 +246,7 @@ async def test_aiter_raw():

@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
stream = AsyncIteratorStream(aiterator=async_streaming_body())

response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())

num_downloaded = response.num_bytes_downloaded
async for part in response.aiter_raw():
Expand Down Expand Up @@ -346,10 +330,9 @@ async def test_aiter_lines():


def test_sync_streaming_response():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

assert response.status_code == 200
Expand All @@ -364,10 +347,9 @@ def test_sync_streaming_response():

@pytest.mark.asyncio
async def test_async_streaming_response():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

assert response.status_code == 200
Expand All @@ -381,10 +363,9 @@ async def test_async_streaming_response():


def test_cannot_read_after_stream_consumed():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

content = b""
Expand All @@ -397,10 +378,9 @@ def test_cannot_read_after_stream_consumed():

@pytest.mark.asyncio
async def test_cannot_aread_after_stream_consumed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

content = b""
Expand All @@ -412,54 +392,33 @@ async def test_cannot_aread_after_stream_consumed():


def test_cannot_read_after_response_closed():
is_closed = False

def close_func():
nonlocal is_closed
is_closed = True

stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

response.close()
assert is_closed

with pytest.raises(httpx.ResponseClosed):
response.read()


@pytest.mark.asyncio
async def test_cannot_aread_after_response_closed():
is_closed = False

async def close_func():
nonlocal is_closed
is_closed = True

stream = AsyncIteratorStream(
aiterator=async_streaming_body(), close_func=close_func
)
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

await response.aclose()
assert is_closed

with pytest.raises(httpx.ResponseClosed):
await response.aread()


@pytest.mark.asyncio
async def test_elapsed_not_available_until_closed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

with pytest.raises(RuntimeError):
Expand Down
71 changes: 70 additions & 1 deletion tests/test_content_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from httpx import StreamConsumed
from httpx._content_streams import ContentStream, encode
from httpx._content_streams import ContentStream, encode, encode_response


@pytest.mark.asyncio
Expand Down Expand Up @@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content():
b"--+++--\r\n",
]
)


@pytest.mark.asyncio
async def test_response_empty_content():
stream = encode_response()
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert stream.can_replay()
assert stream.get_headers() == {}
assert sync_content == b""
assert async_content == b""


@pytest.mark.asyncio
async def test_response_bytes_content():
stream = encode_response(content=b"Hello, world!")
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"


@pytest.mark.asyncio
async def test_response_iterator_content():
def hello_world():
yield b"Hello, "
yield b"world!"

stream = encode_response(content=hello_world())
content = b"".join([part for part in stream])

assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"

with pytest.raises(RuntimeError):
[part async for part in stream]

with pytest.raises(StreamConsumed):
[part for part in stream]


@pytest.mark.asyncio
async def test_response_aiterator_content():
async def hello_world():
yield b"Hello, "
yield b"world!"

stream = encode_response(content=hello_world())
content = b"".join([part async for part in stream])

assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"

with pytest.raises(RuntimeError):
[part for part in stream]

with pytest.raises(StreamConsumed):
[part async for part in stream]


def test_response_invalid_argument():
with pytest.raises(TypeError):
encode_response(123) # type: ignore
Loading