diff --git a/aiohttp/client.py b/aiohttp/client.py
index 9c2fd8073a..a6dabdce6c 100644
--- a/aiohttp/client.py
+++ b/aiohttp/client.py
@@ -75,6 +75,7 @@
     ClientResponse,
     Fingerprint,
     RequestInfo,
+    process_data_to_payload,
 )
 from .client_ws import (
     DEFAULT_WS_CLIENT_TIMEOUT,
@@ -521,6 +522,9 @@ async def _request(
         for trace in traces:
             await trace.send_request_start(method, url.update_query(params), headers)

+        # Preprocess the data so the Payload object can be reused when a redirect is needed
+        data = process_data_to_payload(data)
+
         timer = tm.timer()
         try:
             with timer:
diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py
index b15fe9ebbf..f8a5346dee 100644
--- a/aiohttp/client_reqrep.py
+++ b/aiohttp/client_reqrep.py
@@ -163,6 +163,24 @@ class ConnectionKey:
     proxy_headers_hash: Optional[int]  # hash(CIMultiDict)


+def process_data_to_payload(body: Any) -> Any:
+    # Convert the data to a Payload before entering the redirect loop, so payloads
+    # backed by IO objects stay alive and their stored data can be reused for the next request
+    if body is None:
+        return None
+
+    # FormData
+    if isinstance(body, FormData):
+        body = body()
+
+    try:
+        body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
+    except payload.LookupError:
+        pass  # keep for ClientRequest to handle
+
+    return body
+
+
 class ClientRequest:
     GET_METHODS = {
         hdrs.METH_GET,
diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py
index 6e005a78ba..f49c33e1b4 100644
--- a/aiohttp/formdata.py
+++ b/aiohttp/formdata.py
@@ -28,7 +28,6 @@ def __init__(
         self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary)
         self._fields: List[Any] = []
         self._is_multipart = False
-        self._is_processed = False
         self._quote_fields = quote_fields
         self._charset = charset

@@ -117,8 +116,8 @@ def _gen_form_urlencoded(self) -> payload.BytesPayload:

     def _gen_form_data(self) -> multipart.MultipartWriter:
         """Encode a list of fields using the multipart/form-data MIME format"""
-        if self._is_processed:
-            raise RuntimeError("Form data has been processed already")
+        if not self._fields:
+            return self._writer
         for dispparams, headers, value in self._fields:
             try:
                 if hdrs.CONTENT_TYPE in headers:
@@ -149,7 +148,7 @@ def _gen_form_data(self) -> multipart.MultipartWriter:

             self._writer.append_payload(part)

-        self._is_processed = True
+        self._fields.clear()
         return self._writer

     def __call__(self) -> Payload:
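
Note on the formdata.py change above: _gen_form_data() no longer raises once the fields
have been processed; it appends any pending fields to the underlying MultipartWriter and
clears the list, so the same FormData instance can be serialized more than once (which the
redirect handling added in client.py relies on). A minimal sketch of the new behaviour,
assuming a placeholder "sample.txt" file; this is an illustration, not part of the patch:

import aiohttp

form = aiohttp.FormData()
with open("sample.txt", "rb") as fobj:
    form.add_field("sample.txt", fobj)

    writer1 = form()  # appends the pending field to the MultipartWriter and clears _fields
    writer2 = form()  # previously raised RuntimeError; now returns the same writer again
    assert writer1 is writer2
    assert len(writer1._parts) == 1  # the part appended on the first call is still present
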
diff --git a/aiohttp/payload.py b/aiohttp/payload.py
index ea50b6a38c..e353666f32 100644
--- a/aiohttp/payload.py
+++ b/aiohttp/payload.py
@@ -307,17 +307,39 @@ def __init__(
         if hdrs.CONTENT_DISPOSITION not in self.headers:
             self.set_content_disposition(disposition, filename=self._filename)

+        self._writable = True
+        self._seekable = True
+        try:
+            # Oddly, some IO objects (e.g. tarfile.TarFile._Stream) do not expose the
+            # `seekable()` method that IOBase defines, so probe `seek()` and `tell()`
+            # directly to check whether they are available
+            self._value.seek(self._value.tell())
+        except (AttributeError, OSError):
+            self._seekable = False
+
+        if self._seekable:
+            self._stream_pos = self._value.tell()
+        else:
+            self._stream_pos = 0
+
     async def write(self, writer: AbstractStreamWriter) -> None:
         loop = asyncio.get_event_loop()
-        try:
+        if self._seekable:
+            await loop.run_in_executor(None, self._value.seek, self._stream_pos)
+        elif not self._writable:
+            raise RuntimeError(
+                f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
+            )
+        chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        while chunk:
+            await writer.write(chunk)
             chunk = await loop.run_in_executor(None, self._value.read, 2**16)
-            while chunk:
-                await writer.write(chunk)
-                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
-        finally:
-            await loop.run_in_executor(None, self._value.close)
+        if not self._seekable:
+            self._writable = False  # Non-seekable IO `_value` can only be consumed once

     def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        if self._seekable:
+            self._value.seek(self._stream_pos)
         return "".join(r.decode(encoding, errors) for r in self._value.readlines())
@@ -354,40 +376,50 @@ def __init__(
     @property
     def size(self) -> Optional[int]:
         try:
-            return os.fstat(self._value.fileno()).st_size - self._value.tell()
+            return os.fstat(self._value.fileno()).st_size - self._stream_pos
         except OSError:
             return None

     def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        if self._seekable:
+            self._value.seek(self._stream_pos)
         return self._value.read()

     async def write(self, writer: AbstractStreamWriter) -> None:
         loop = asyncio.get_event_loop()
-        try:
+        if self._seekable:
+            await loop.run_in_executor(None, self._value.seek, self._stream_pos)
+        elif not self._writable:
+            raise RuntimeError(
+                f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
+            )
+        chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        while chunk:
+            data = (
+                chunk.encode(encoding=self._encoding)
+                if self._encoding
+                else chunk.encode()
+            )
+            await writer.write(data)
             chunk = await loop.run_in_executor(None, self._value.read, 2**16)
-            while chunk:
-                data = (
-                    chunk.encode(encoding=self._encoding)
-                    if self._encoding
-                    else chunk.encode()
-                )
-                await writer.write(data)
-                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
-        finally:
-            await loop.run_in_executor(None, self._value.close)
+        if not self._seekable:
+            self._writable = False  # Non-seekable IO `_value` can only be consumed once


 class BytesIOPayload(IOBasePayload):
     _value: io.BytesIO

     @property
-    def size(self) -> int:
-        position = self._value.tell()
-        end = self._value.seek(0, os.SEEK_END)
-        self._value.seek(position)
-        return end - position
+    def size(self) -> Optional[int]:
+        if self._seekable:
+            end = self._value.seek(0, os.SEEK_END)
+            self._value.seek(self._stream_pos)
+            return end - self._stream_pos
+        return None

     def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        if self._seekable:
+            self._value.seek(self._stream_pos)
         return self._value.read().decode(encoding, errors)
@@ -397,7 +429,7 @@ class BufferedReaderPayload(IOBasePayload):
     @property
     def size(self) -> Optional[int]:
         try:
-            return os.fstat(self._value.fileno()).st_size - self._value.tell()
+            return os.fstat(self._value.fileno()).st_size - self._stream_pos
         except (OSError, AttributeError):
             # data.fileno() is not supported, e.g.
             # io.BufferedReader(io.BytesIO(b'data'))
@@ -406,6 +438,8 @@ def size(self) -> Optional[int]:
             return None

     def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        if self._seekable:
+            self._value.seek(self._stream_pos)
         return self._value.read().decode(encoding, errors)
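
Note on the payload.py change above: IO payloads now record their starting offset
(_stream_pos) and, when the underlying object is seekable, rewind to it on every write(),
so the same payload can be replayed after a redirect; write() also no longer closes the
object it was given. A non-seekable object can only be written once, after which write()
raises the RuntimeError shown above. A workaround sketch for non-seekable streams, in the
spirit of that error message's advice; the archive name and the URL are placeholders:

import io
import tarfile

import aiohttp


async def upload_members(session: aiohttp.ClientSession, url: str) -> None:
    # mode="r|" gives a streaming, non-seekable view of the archive members
    with tarfile.open("archive.tar", mode="r|") as tf:
        for entry in tf:
            stream = tf.extractfile(entry)
            if stream is None:  # skip directories and other non-file members
                continue
            # Copy the non-seekable member into a seekable buffer so the payload
            # can be replayed if the server responds with a redirect.
            buffered = io.BytesIO(stream.read())
            async with session.post(url, data=buffered) as resp:
                resp.raise_for_status()
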
diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py
index 8713f3682f..bff815c12f 100644
--- a/tests/test_client_functional.py
+++ b/tests/test_client_functional.py
@@ -1544,6 +1544,8 @@ async def test_GET_DEFLATE(
     aiohttp_client: AiohttpClient, data: Optional[bytes]
 ) -> None:
     async def handler(request: web.Request) -> web.Response:
+        recv_data = await request.read()
+        assert recv_data == b""  # both cases should receive empty bytes
         return web.json_response({"ok": True})

     write_mock = None
@@ -1553,10 +1555,10 @@ async def write_bytes(
         self: ClientRequest, writer: StreamWriter, conn: Connection
     ) -> None:
         nonlocal write_mock
-        original_write = writer._write
+        original_write = writer.write

         with mock.patch.object(
-            writer, "_write", autospec=True, spec_set=True, side_effect=original_write
+            writer, "write", autospec=True, spec_set=True, side_effect=original_write
         ) as write_mock:
             await original_write_bytes(self, writer, conn)

@@ -1571,8 +1573,8 @@ async def write_bytes(
     assert content == {"ok": True}

     assert write_mock is not None
-    # No chunks should have been sent for an empty body.
-    write_mock.assert_not_called()
+    # An empty b"" should have been sent for an empty body.
+    write_mock.assert_called_once_with(b"")


 async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
diff --git a/tests/test_formdata.py b/tests/test_formdata.py
index 7ddd53038c..b15e493573 100644
--- a/tests/test_formdata.py
+++ b/tests/test_formdata.py
@@ -1,9 +1,12 @@
 import io
+import pathlib
+import tarfile
 from unittest import mock

 import pytest

 from aiohttp import FormData, web
+from aiohttp.client_exceptions import ClientConnectionError
 from aiohttp.http_writer import StreamWriter
 from aiohttp.pytest_plugin import AiohttpClient

@@ -95,28 +98,113 @@ async def test_formdata_field_name_is_not_quoted(
     assert b'name="email 1"' in buf


-async def test_mark_formdata_as_processed(aiohttp_client: AiohttpClient) -> None:
-    async def handler(request: web.Request) -> web.Response:
-        return web.Response()
+async def test_formdata_boundary_param() -> None:
+    boundary = "some_boundary"
+    form = FormData(boundary=boundary)
+    assert form._writer.boundary == boundary

-    app = web.Application()
-    app.add_routes([web.post("/", handler)])
-    client = await aiohttp_client(app)

+async def test_formdata_on_redirect(aiohttp_client: AiohttpClient) -> None:
+    with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open("rb") as fobj:
+        content = fobj.read()
+        fobj.seek(0)

-    data = FormData()
-    data.add_field("test", "test_value", content_type="application/json")
+        async def handler_0(request: web.Request) -> web.Response:
+            raise web.HTTPPermanentRedirect("/1")

-    resp = await client.post("/", data=data)
-    assert len(data._writer._parts) == 1
+        async def handler_1(request: web.Request) -> web.Response:
+            req_data = await request.post()
+            assert ["sample.txt"] == list(req_data.keys())
+            file_field = req_data["sample.txt"]
+            assert isinstance(file_field, web.FileField)
+            assert content == file_field.file.read()
+            return web.Response()

-    with pytest.raises(RuntimeError):
-        await client.post("/", data=data)
+        app = web.Application()
+        app.router.add_post("/0", handler_0)
+        app.router.add_post("/1", handler_1)

-    resp.release()
+        client = await aiohttp_client(app)

+        data = FormData()
+        data.add_field("sample.txt", fobj)

-async def test_formdata_boundary_param() -> None:
-    boundary = "some_boundary"
-    form = FormData(boundary=boundary)
-    assert form._writer.boundary == boundary
+        resp = await client.post("/0", data=data)
+        assert len(data._writer._parts) == 1
+        assert resp.status == 200
+
+        resp.release()
+
+
+async def test_formdata_on_redirect_after_recv(aiohttp_client: AiohttpClient) -> None:
+    with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open("rb") as fobj:
+        content = fobj.read()
+        fobj.seek(0)
+
+        async def handler_0(request: web.Request) -> web.Response:
+            req_data = await request.post()
+            assert ["sample.txt"] == list(req_data.keys())
+            file_field = req_data["sample.txt"]
+            assert isinstance(file_field, web.FileField)
+            assert content == file_field.file.read()
+            raise web.HTTPPermanentRedirect("/1")
+
+        async def handler_1(request: web.Request) -> web.Response:
+            req_data = await request.post()
+            assert ["sample.txt"] == list(req_data.keys())
+            file_field = req_data["sample.txt"]
+            assert isinstance(file_field, web.FileField)
+            assert content == file_field.file.read()
+            return web.Response()
+
+        app = web.Application()
+        app.router.add_post("/0", handler_0)
+        app.router.add_post("/1", handler_1)
+
+        client = await aiohttp_client(app)
+
+        data = FormData()
+        data.add_field("sample.txt", fobj)
+
+        resp = await client.post("/0", data=data)
+        assert len(data._writer._parts) == 1
+        assert resp.status == 200
+
+        resp.release()
+
+
+async def test_streaming_tarfile_on_redirect(aiohttp_client: AiohttpClient) -> None:
+    data = b"This is a tar file payload text file."
+
+    async def handler_0(request: web.Request) -> web.Response:
+        await request.read()
+        raise web.HTTPPermanentRedirect("/1")
+
+    async def handler_1(request: web.Request) -> web.Response:
+        await request.read()
+        return web.Response()
+
+    app = web.Application()
+    app.router.add_post("/0", handler_0)
+    app.router.add_post("/1", handler_1)
+
+    client = await aiohttp_client(app)
+
+    buf = io.BytesIO()
+    with tarfile.open(fileobj=buf, mode="w") as tf:
+        ti = tarfile.TarInfo(name="payload1.txt")
+        ti.size = len(data)
+        tf.addfile(tarinfo=ti, fileobj=io.BytesIO(data))
+
+    # Streaming tarfile.
+    buf.seek(0)
+    tf = tarfile.open(fileobj=buf, mode="r|")
+    for entry in tf:
+        with pytest.raises(ClientConnectionError) as exc_info:
+            await client.post("/0", data=tf.extractfile(entry))
+        raw_exc_info = exc_info._excinfo
+        assert isinstance(raw_exc_info, tuple)
+        cause_exc = raw_exc_info[1].__cause__
+        assert isinstance(cause_exc, RuntimeError)
+        assert len(cause_exc.args) == 1
+        assert cause_exc.args[0].startswith("Non-seekable IO payload")
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
index 55befdbb60..756f8cffc9 100644
--- a/tests/test_multipart.py
+++ b/tests/test_multipart.py
@@ -1422,6 +1422,35 @@ async def test_reset_content_disposition_header(
             b' attachments; filename="bug.py"'
         )

+    async def test_multiple_write_on_io_payload(
+        self, buf: bytearray, stream: Stream
+    ) -> None:
+        with aiohttp.MultipartWriter("form-data", boundary=":") as writer:
+            with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open(
+                "rb"
+            ) as fobj:
+                content = fobj.read()
+                fobj.seek(0)
+
+                target_buf = (
+                    b'--:\r\nContent-Type: text/plain\r\nContent-Disposition: attachment; filename="sample.txt"\r\n\r\n'
+                    + content
+                    + b"\r\n--:--\r\n"
+                )
+
+                writer.append(fobj)
+                assert len(writer._parts) == 1
+                assert isinstance(writer._parts[0][0], payload.BufferedReaderPayload)
+
+                await writer.write(stream)
+                assert bytes(buf) == target_buf
+
+                buf.clear()
+                assert bytes(buf) == b""
+
+                await writer.write(stream)
+                assert bytes(buf) == target_buf
+

 async def test_async_for_reader() -> None:
     data: Tuple[Dict[str, str], int, bytes, bytes, bytes] = (
diff --git a/tests/test_payload.py b/tests/test_payload.py
index 8c04c5cba5..c92f6b7fb7 100644
--- a/tests/test_payload.py
+++ b/tests/test_payload.py
@@ -1,6 +1,8 @@
 import array
-from io import StringIO
+import io
+import pathlib
 from typing import Any, AsyncIterator, Iterator
+from unittest import mock

 import pytest

@@ -92,13 +94,42 @@ def test_string_payload() -> None:


 def test_string_io_payload() -> None:
-    s = StringIO("ű" * 5000)
+    s = io.StringIO("ű" * 5000)
     p = payload.StringIOPayload(s)
     assert p.encoding == "utf-8"
     assert p.content_type == "text/plain; charset=utf-8"
     assert p.size == 10000


+def test_text_io_payload() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    filesize = filepath.stat().st_size
+    with filepath.open("r") as f:
+        p = payload.TextIOPayload(f)
+        assert p.encoding == "utf-8"
+        assert p.content_type == "text/plain; charset=utf-8"
+        assert p.size == filesize
+        assert not f.closed
+
+
+def test_bytes_io_payload() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    filesize = filepath.stat().st_size
+    with filepath.open("rb") as f:
+        p = payload.BytesIOPayload(f)
+        assert p.size == filesize
+        assert not f.closed
+
+
+def test_buffered_reader_payload() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    filesize = filepath.stat().st_size
+    with filepath.open("rb") as f:
+        p = payload.BufferedReaderPayload(f)
+        assert p.size == filesize
+        assert not f.closed
+
+
 def test_async_iterable_payload_default_content_type() -> None:
     async def gen() -> AsyncIterator[bytes]:
         return
@@ -120,3 +151,85 @@ async def gen() -> AsyncIterator[bytes]:
 def test_async_iterable_payload_not_async_iterable() -> None:
     with pytest.raises(TypeError):
         payload.AsyncIterablePayload(object())  # type: ignore[arg-type]
+
+
+async def test_string_io_payload_write() -> None:
+    content = "ű" * 5000
+
+    s = io.StringIO(content)
+    p = payload.StringIOPayload(s)
+
+    with mock.patch("aiohttp.http_writer.StreamWriter") as mock_obj:
+        instance = mock_obj.return_value
+        instance.write = mock.AsyncMock()
+
+        await p.write(instance)
+        instance.write.assert_called_once_with(content.encode("utf-8"))
+
+        instance.write.reset_mock()
+
+        await p.write(instance)
+        instance.write.assert_called_once_with(content.encode("utf-8"))
+
+
+async def test_text_io_payload_write() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    with filepath.open("r") as f:
+        content = f.read()
+        f.seek(0)
+
+        p = payload.TextIOPayload(f)
+
+        with mock.patch("aiohttp.http_writer.StreamWriter") as mock_obj:
+            instance = mock_obj.return_value
+            instance.write = mock.AsyncMock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content.encode("utf-8"))  # 1 chunk
+
+            instance.write.reset_mock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content.encode("utf-8"))  # 1 chunk
+
+
+async def test_bytes_io_payload_write() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    with filepath.open("rb") as f:
+        content = f.read()
+    with io.BytesIO(content) as bf:
+
+        p = payload.BytesIOPayload(bf)
+
+        with mock.patch("aiohttp.http_writer.StreamWriter") as mock_obj:
+            instance = mock_obj.return_value
+            instance.write = mock.AsyncMock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content)  # 1 chunk
+
+            instance.write.reset_mock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content)  # 1 chunk
+
+
+async def test_buffered_reader_payload_write() -> None:
+    filepath = pathlib.Path(__file__).parent / "sample.txt"
+    with filepath.open("rb") as f:
+        content = f.read()
+        f.seek(0)
+
+        p = payload.BufferedReaderPayload(f)
+
+        with mock.patch("aiohttp.http_writer.StreamWriter") as mock_obj:
+            instance = mock_obj.return_value
+            instance.write = mock.AsyncMock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content)  # 1 chunk
+
+            instance.write.reset_mock()
+
+            await p.write(instance)
+            instance.write.assert_called_once_with(content)  # 1 chunk
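
For completeness, an end-to-end sketch of the behaviour the tests above exercise, assuming
a hypothetical server where POST /old answers with a 307/308 redirect to /new; the URL and
the file name are placeholders, not part of the patch:

import asyncio

import aiohttp


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        with open("report.bin", "rb") as f:
            # On the redirect the file payload is seeked back to its starting offset
            # and re-sent to the redirect target instead of arriving empty.
            async with session.post("http://localhost:8080/old", data=f) as resp:
                resp.raise_for_status()
            # write() no longer closes the file object it was handed.
            assert not f.closed


asyncio.run(main())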