From e90b2ccf2a5a9eec63a26ba6784ac9a32b9bc13b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 7 Sep 2020 09:36:37 +0100 Subject: [PATCH 1/2] Support Response(content=) --- httpx/_content_streams.py | 35 +++++++++-------- httpx/_models.py | 11 ++++-- httpx/_types.py | 2 + tests/models/test_responses.py | 46 +++++----------------- tests/test_content_streams.py | 71 +++++++++++++++++++++++++++++++++- tests/test_decoders.py | 13 ++----- 6 files changed, 110 insertions(+), 68 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 402fa959c8..3cd2196ab4 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -8,7 +8,7 @@ import httpcore from ._exceptions import StreamConsumed -from ._types import FileContent, FileTypes, RequestData, RequestFiles +from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent from ._utils import ( format_form_param, guess_content_type, @@ -72,11 +72,8 @@ class IteratorStream(ContentStream): Request content encoded as plain bytes, using an byte iterator. """ - def __init__( - self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, iterator: typing.Iterator[bytes]) -> None: self.iterator = iterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -95,21 +92,14 @@ def __iter__(self) -> typing.Iterator[bytes]: def __aiter__(self) -> typing.AsyncIterator[bytes]: raise RuntimeError("Attempted to call a async iterator on an sync stream.") - def close(self) -> None: - if self.close_func is not None: - self.close_func() - class AsyncIteratorStream(ContentStream): """ Request content encoded as plain bytes, using an async byte iterator. """ - def __init__( - self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None: self.aiterator = aiterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -128,10 +118,6 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: async for part in self.aiterator: yield part - async def aclose(self) -> None: - if self.close_func is not None: - await self.close_func() - class JSONStream(ContentStream): """ @@ -402,3 +388,18 @@ def encode( return IteratorStream(iterator=data) raise TypeError(f"Unexpected type for 'data', {type(data)!r}") + + +def encode_response(content: ResponseContent = None) -> ContentStream: + if content is None: + return ByteStream(b"") + elif isinstance(content, bytes): + return ByteStream(body=content) + elif hasattr(content, "__aiter__"): + content = typing.cast(typing.AsyncIterator[bytes], content) + return AsyncIteratorStream(aiterator=content) + elif hasattr(content, "__iter__"): + content = typing.cast(typing.Iterator[bytes], content) + return IteratorStream(iterator=content) + + raise TypeError(f"Unexpected type for 'content', {type(content)!r}") diff --git a/httpx/_models.py b/httpx/_models.py index 713281e662..694e520c2c 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -14,7 +14,7 @@ import rfc3986 import rfc3986.exceptions -from ._content_streams import ByteStream, ContentStream, encode +from ._content_streams import ByteStream, ContentStream, encode, encode_response from ._decoders import ( SUPPORTED_DECODERS, Decoder, @@ -44,6 +44,7 @@ QueryParamTypes, RequestData, RequestFiles, + ResponseContent, URLTypes, ) from ._utils import ( @@ -674,7 +675,7 @@ def __init__( http_version: str = None, headers: HeaderTypes = None, stream: ContentStream = None, - content: bytes = None, + content: ResponseContent = None, history: typing.List["Response"] = None, elapsed_func: typing.Callable = None, ): @@ -694,8 +695,10 @@ def __init__( if stream is not None: self._raw_stream = stream else: - self._raw_stream = ByteStream(body=content or b"") - self.read() + self._raw_stream = encode_response(content) + if content is None or isinstance(content, bytes): + # Load the response body, except for streaming content. + self.read() @property def elapsed(self) -> datetime.timedelta: diff --git a/httpx/_types.py b/httpx/_types.py index 3a90ee42e7..8989b2826c 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -63,6 +63,8 @@ None, ] +ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]] + RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]] FileContent = Union[IO[str], IO[bytes], str, bytes] diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 2b07a27040..9c4d285091 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -5,7 +5,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream, IteratorStream def streaming_body(): @@ -215,10 +214,9 @@ async def test_aread(): def test_iter_raw(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) raw = b"" @@ -229,10 +227,9 @@ def test_iter_raw(): @pytest.mark.asyncio async def test_aiter_raw(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) raw = b"" @@ -317,10 +314,9 @@ async def test_aiter_lines(): def test_sync_streaming_response(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) assert response.status_code == 200 @@ -335,10 +331,9 @@ def test_sync_streaming_response(): @pytest.mark.asyncio async def test_async_streaming_response(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) assert response.status_code == 200 @@ -352,10 +347,9 @@ async def test_async_streaming_response(): def test_cannot_read_after_stream_consumed(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) content = b"" @@ -368,10 +362,9 @@ def test_cannot_read_after_stream_consumed(): @pytest.mark.asyncio async def test_cannot_aread_after_stream_consumed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) content = b"" @@ -383,54 +376,33 @@ async def test_cannot_aread_after_stream_consumed(): def test_cannot_read_after_response_closed(): - is_closed = False - - def close_func(): - nonlocal is_closed - is_closed = True - - stream = IteratorStream(iterator=streaming_body(), close_func=close_func) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) response.close() - assert is_closed - with pytest.raises(httpx.ResponseClosed): response.read() @pytest.mark.asyncio async def test_cannot_aread_after_response_closed(): - is_closed = False - - async def close_func(): - nonlocal is_closed - is_closed = True - - stream = AsyncIteratorStream( - aiterator=async_streaming_body(), close_func=close_func - ) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) await response.aclose() - assert is_closed - with pytest.raises(httpx.ResponseClosed): await response.aread() @pytest.mark.asyncio async def test_elapsed_not_available_until_closed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) with pytest.raises(RuntimeError): diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index 140aa8d2af..2d1de1f1c0 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -3,7 +3,7 @@ import pytest from httpx import StreamConsumed -from httpx._content_streams import ContentStream, encode +from httpx._content_streams import ContentStream, encode, encode_response @pytest.mark.asyncio @@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content(): b"--+++--\r\n", ] ) + + +@pytest.mark.asyncio +async def test_response_empty_content(): + stream = encode_response() + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == {} + assert sync_content == b"" + assert async_content == b"" + + +@pytest.mark.asyncio +async def test_response_bytes_content(): + stream = encode_response(content=b"Hello, world!") + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == {"Content-Length": "13"} + assert sync_content == b"Hello, world!" + assert async_content == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_response_iterator_content(): + def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response(content=hello_world()) + content = b"".join([part for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part async for part in stream] + + with pytest.raises(StreamConsumed): + [part for part in stream] + + +@pytest.mark.asyncio +async def test_response_aiterator_content(): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response(content=hello_world()) + content = b"".join([part async for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part for part in stream] + + with pytest.raises(StreamConsumed): + [part async for part in stream] + + +def test_response_invalid_argument(): + with pytest.raises(TypeError): + encode_response(123) # type: ignore diff --git a/tests/test_decoders.py b/tests/test_decoders.py index dbbaac5450..7dfca9ef50 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -4,7 +4,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream from httpx._decoders import ( BrotliDecoder, DeflateDecoder, @@ -130,11 +129,10 @@ async def compress(body): yield compressor.flush() headers = [(b"Content-Encoding", b"gzip")] - stream = AsyncIteratorStream(aiterator=compress(body)) response = httpx.Response( 200, headers=headers, - stream=stream, + content=compress(body), ) assert not hasattr(response, "body") assert await response.aread() == body @@ -199,19 +197,17 @@ async def iterator(): yield chunk # Accessing `.text` on a read response. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) await response.aread() assert response.text == (b"".join(data)).decode(encoding) # Streaming `.aiter_text` iteratively. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) text = "".join([part async for part in response.aiter_text()]) assert text == (b"".join(data)).decode(encoding) @@ -224,11 +220,10 @@ async def iterator(): yield b"\x83" yield b"\x89\x83x\x83\x8b" - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, headers=[(b"Content-Type", b"text/html; charset=shift-jis")], - stream=stream, + content=iterator(), ) await response.aread() From 6af5e50bec87d8bf7cac00eb7c66cfdffd311378 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 10 Sep 2020 20:02:51 +0100 Subject: [PATCH 2/2] Update test for merged master --- tests/models/test_responses.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 5517ab8efc..b52e4846f3 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -226,12 +226,7 @@ def test_iter_raw(): def test_iter_raw_increments_updates_counter(): - stream = IteratorStream(iterator=streaming_body()) - - response = httpx.Response( - 200, - stream=stream, - ) + response = httpx.Response(200, content=streaming_body()) num_downloaded = response.num_bytes_downloaded for part in response.iter_raw(): @@ -241,10 +236,7 @@ def test_iter_raw_increments_updates_counter(): @pytest.mark.asyncio async def test_aiter_raw(): - response = httpx.Response( - 200, - content=async_streaming_body(), - ) + response = httpx.Response(200, content=async_streaming_body()) raw = b"" async for part in response.aiter_raw(): @@ -254,12 +246,7 @@ async def test_aiter_raw(): @pytest.mark.asyncio async def test_aiter_raw_increments_updates_counter(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) - - response = httpx.Response( - 200, - stream=stream, - ) + response = httpx.Response(200, content=async_streaming_body()) num_downloaded = response.num_bytes_downloaded async for part in response.aiter_raw():