Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 'I/O operation on closed file' and 'Form data has been processed already' upon redirect on multipart data #9201

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ClientResponse,
Fingerprint,
RequestInfo,
process_data_to_payload,
)
from .client_ws import (
DEFAULT_WS_CLIENT_TIMEOUT,
Expand Down Expand Up @@ -521,6 +522,9 @@ async def _request(
for trace in traces:
await trace.send_request_start(method, url.update_query(params), headers)

# preprocess the data so we can reuse the Payload object when redirect is needed
data = process_data_to_payload(data)

timer = tm.timer()
try:
with timer:
Expand Down
18 changes: 18 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ class ConnectionKey:
proxy_headers_hash: Optional[int] # hash(CIMultiDict)


def process_data_to_payload(body: Any) -> Any:
# this function is used to convert data to payload before looping into redirects,
# so payload with io objects can be keep alive and use the stored data for the next request
if body is None:
return None

# FormData
if isinstance(body, FormData):
body = body()

try:
body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
except payload.LookupError:
pass # keep for ClientRequest to handle

return body


class ClientRequest:
GET_METHODS = {
hdrs.METH_GET,
Expand Down
7 changes: 3 additions & 4 deletions aiohttp/formdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary)
self._fields: List[Any] = []
self._is_multipart = False
self._is_processed = False
self._quote_fields = quote_fields
self._charset = charset

Expand Down Expand Up @@ -117,8 +116,8 @@

def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
if self._is_processed:
raise RuntimeError("Form data has been processed already")
if not self._fields:
return self._writer

Check warning on line 120 in aiohttp/formdata.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/formdata.py#L120

Added line #L120 was not covered by tests
Comment on lines +119 to +120
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't appear to save much, one call to .clear().

Suggested change
if not self._fields:
return self._writer

for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
Expand Down Expand Up @@ -149,7 +148,7 @@

self._writer.append_payload(part)

self._is_processed = True
self._fields.clear()
return self._writer

def __call__(self) -> Payload:
Expand Down
82 changes: 58 additions & 24 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,39 @@
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(disposition, filename=self._filename)

self._writable = True
self._seekable = True
try:
# It is weird but some IO object dont have `seekable()` method as IOBase object,
# it seems better for us to direct try if the `seek()` and `tell()` is available
# e.g. tarfile.TarFile._Stream
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class doesn't appear to be an IOBase, so I'm wondering why it reaches this code. Does it match one of the payload subclasses?

self._value.seek(self._value.tell())
except (AttributeError, OSError):
self._seekable = False

if self._seekable:
self._stream_pos = self._value.tell()
else:
self._stream_pos = 0

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
await loop.run_in_executor(None, self._value.seek, self._stream_pos)
elif not self._writable:
raise RuntimeError(
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)

Check warning on line 342 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L342

Added line #L342 was not covered by tests
return "".join(r.decode(encoding, errors) for r in self._value.readlines())


Expand Down Expand Up @@ -354,40 +376,50 @@
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
return os.fstat(self._value.fileno()).st_size - self._stream_pos
except OSError:
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)
return self._value.read()

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
await loop.run_in_executor(None, self._value.seek, self._stream_pos)
elif not self._writable:
raise RuntimeError(

Check warning on line 393 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L393

Added line #L393 was not covered by tests
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once

Check warning on line 406 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L406

Added line #L406 was not covered by tests


class BytesIOPayload(IOBasePayload):
_value: io.BytesIO

@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position
def size(self) -> Optional[int]:
if self._seekable:
end = self._value.seek(0, os.SEEK_END)
self._value.seek(self._stream_pos)
return end - self._stream_pos
return None

Check warning on line 418 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L418

Added line #L418 was not covered by tests

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)
return self._value.read().decode(encoding, errors)


Expand All @@ -397,7 +429,7 @@
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
return os.fstat(self._value.fileno()).st_size - self._stream_pos
except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
Expand All @@ -406,6 +438,8 @@
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)
return self._value.read().decode(encoding, errors)


Expand Down
10 changes: 6 additions & 4 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,8 @@ async def test_GET_DEFLATE(
aiohttp_client: AiohttpClient, data: Optional[bytes]
) -> None:
async def handler(request: web.Request) -> web.Response:
recv_data = await request.read()
assert recv_data == b"" # both cases should receive empty bytes
return web.json_response({"ok": True})

write_mock = None
Expand All @@ -1553,10 +1555,10 @@ async def write_bytes(
self: ClientRequest, writer: StreamWriter, conn: Connection
) -> None:
nonlocal write_mock
original_write = writer._write
original_write = writer.write

with mock.patch.object(
writer, "_write", autospec=True, spec_set=True, side_effect=original_write
writer, "write", autospec=True, spec_set=True, side_effect=original_write
) as write_mock:
await original_write_bytes(self, writer, conn)

Expand All @@ -1571,8 +1573,8 @@ async def write_bytes(
assert content == {"ok": True}

assert write_mock is not None
# No chunks should have been sent for an empty body.
write_mock.assert_not_called()
# Empty b"" should have been sent for an empty body.
write_mock.assert_called_once_with(b"")


async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
Expand Down
122 changes: 105 additions & 17 deletions tests/test_formdata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import io
import pathlib
import tarfile
from unittest import mock

import pytest

from aiohttp import FormData, web
from aiohttp.client_exceptions import ClientConnectionError
from aiohttp.http_writer import StreamWriter
from aiohttp.pytest_plugin import AiohttpClient

Expand Down Expand Up @@ -95,28 +98,113 @@
assert b'name="email 1"' in buf


async def test_mark_formdata_as_processed(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response()
async def test_formdata_boundary_param() -> None:
boundary = "some_boundary"
form = FormData(boundary=boundary)
assert form._writer.boundary == boundary

app = web.Application()
app.add_routes([web.post("/", handler)])

client = await aiohttp_client(app)
async def test_formdata_on_redirect(aiohttp_client: AiohttpClient) -> None:
with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open("rb") as fobj:
content = fobj.read()
fobj.seek(0)

data = FormData()
data.add_field("test", "test_value", content_type="application/json")
async def handler_0(request: web.Request) -> web.Response:
raise web.HTTPPermanentRedirect("/1")

resp = await client.post("/", data=data)
assert len(data._writer._parts) == 1
async def handler_1(request: web.Request) -> web.Response:
req_data = await request.post()
assert ["sample.txt"] == list(req_data.keys())
file_field = req_data["sample.txt"]
assert isinstance(file_field, web.FileField)
assert content == file_field.file.read()
return web.Response()

with pytest.raises(RuntimeError):
await client.post("/", data=data)
app = web.Application()
app.router.add_post("/0", handler_0)
app.router.add_post("/1", handler_1)

resp.release()
client = await aiohttp_client(app)

data = FormData()
data.add_field("sample.txt", fobj)

async def test_formdata_boundary_param() -> None:
boundary = "some_boundary"
form = FormData(boundary=boundary)
assert form._writer.boundary == boundary
resp = await client.post("/0", data=data)
assert len(data._writer._parts) == 1
assert resp.status == 200

resp.release()


async def test_formdata_on_redirect_after_recv(aiohttp_client: AiohttpClient) -> None:
with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open("rb") as fobj:
content = fobj.read()
fobj.seek(0)

async def handler_0(request: web.Request) -> web.Response:
req_data = await request.post()
assert ["sample.txt"] == list(req_data.keys())
file_field = req_data["sample.txt"]
assert isinstance(file_field, web.FileField)
assert content == file_field.file.read()
raise web.HTTPPermanentRedirect("/1")

async def handler_1(request: web.Request) -> web.Response:
req_data = await request.post()
assert ["sample.txt"] == list(req_data.keys())
file_field = req_data["sample.txt"]
assert isinstance(file_field, web.FileField)
assert content == file_field.file.read()
return web.Response()

app = web.Application()
app.router.add_post("/0", handler_0)
app.router.add_post("/1", handler_1)

client = await aiohttp_client(app)

data = FormData()
data.add_field("sample.txt", fobj)

resp = await client.post("/0", data=data)
assert len(data._writer._parts) == 1
assert resp.status == 200

resp.release()


async def test_streaming_tarfile_on_redirect(aiohttp_client: AiohttpClient) -> None:
data = b"This is a tar file payload text file."

async def handler_0(request: web.Request) -> web.Response:
await request.read()
raise web.HTTPPermanentRedirect("/1")

async def handler_1(request: web.Request) -> web.Response:
await request.read()
return web.Response()

Check warning on line 185 in tests/test_formdata.py

View check run for this annotation

Codecov / codecov/patch

tests/test_formdata.py#L185

Added line #L185 was not covered by tests

app = web.Application()
app.router.add_post("/0", handler_0)
app.router.add_post("/1", handler_1)

client = await aiohttp_client(app)

buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
ti = tarfile.TarInfo(name="payload1.txt")
ti.size = len(data)
tf.addfile(tarinfo=ti, fileobj=io.BytesIO(data))

# Streaming tarfile.
buf.seek(0)
tf = tarfile.open(fileobj=buf, mode="r|")
for entry in tf:
with pytest.raises(ClientConnectionError) as exc_info:
await client.post("/0", data=tf.extractfile(entry))
raw_exc_info = exc_info._excinfo
assert isinstance(raw_exc_info, tuple)
cause_exc = raw_exc_info[1].__cause__
assert isinstance(cause_exc, RuntimeError)
assert len(cause_exc.args) == 1
assert cause_exc.args[0].startswith("Non-seekable IO payload")
Loading
Loading