Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ConnectionResetError not being raised when the transport is closed #7180

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
eab15e3
Fix ConnectionResetError not being raised when the transport is closed
bdraco Jan 22, 2023
5a43af0
mypy
bdraco Jan 22, 2023
4a9da40
mypy
bdraco Jan 22, 2023
2da4064
add cover
bdraco Jan 22, 2023
8ac7f38
add cover
bdraco Jan 22, 2023
9cb25db
change
bdraco Jan 22, 2023
0c8b69a
contributors
bdraco Jan 22, 2023
3b90d3c
typo
bdraco Jan 22, 2023
f0ae99e
typo
bdraco Jan 22, 2023
4f44c32
typo
bdraco Jan 22, 2023
a393ece
single source of truth for the transport
bdraco Jan 23, 2023
8670f0a
preen
bdraco Jan 23, 2023
78273dd
empty
bdraco Jan 23, 2023
7f95175
empty
bdraco Jan 24, 2023
860fde8
rst syntax
bdraco Jan 27, 2023
d34fc70
Update aiohttp/http_writer.py
bdraco Jan 27, 2023
516a5c1
docstring
bdraco Jan 27, 2023
811e3e1
Merge remote-tracking branch 'bdraco/fix_connection_reset_error_not_r…
bdraco Jan 27, 2023
21bb53b
matches
bdraco Jan 27, 2023
c946e93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
5833948
mocker
bdraco Jan 27, 2023
ded9608
Merge remote-tracking branch 'bdraco/fix_connection_reset_error_not_r…
bdraco Jan 27, 2023
0ff872e
fix incorrect kwarg
bdraco Jan 27, 2023
96090d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
b93ed94
empty commit to rerun ci
bdraco Jan 27, 2023
043c1d4
Merge remote-tracking branch 'bdraco/fix_connection_reset_error_not_r…
bdraco Jan 27, 2023
05da43e
empty commit to rerun ci
bdraco Jan 27, 2023
4e10760
empty
bdraco Jan 27, 2023
90f4257
add connected property
bdraco Jan 28, 2023
471fa2e
cleanup
bdraco Jan 28, 2023
0172016
revert change to make mypy happy
bdraco Jan 28, 2023
d51d68a
revert change to make mypy happy
bdraco Jan 28, 2023
74b7e77
update tests
bdraco Jan 28, 2023
0f44ef2
still check for is_closing
bdraco Jan 28, 2023
a7e53ba
Update tests/test_client_proto.py
Dreamsorcerer Feb 1, 2023
ed76669
Update aiohttp/base_protocol.py
Dreamsorcerer Feb 1, 2023
53384ab
Update aiohttp/base_protocol.py
Dreamsorcerer Feb 1, 2023
115ce1d
Merge branch 'master' into fix_connection_reset_error_not_raised
Dreamsorcerer Feb 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7180.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol``
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ Ilya Gruzinov
Ingmar Steen
Ivan Lakovic
Ivan Larin
J. Nick Koston
Jacob Champion
Jaesung Lee
Jake Davis
Expand Down
9 changes: 6 additions & 3 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._connection_lost = False
self._reading_paused = False

self.transport: Optional[asyncio.Transport] = None

@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None

def pause_writing(self) -> None:
assert not self._paused
self._paused = True
Expand Down Expand Up @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = tr

def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
Expand All @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
waiter.set_exception(exc)

async def _drain_helper(self) -> None:
if self._connection_lost:
if not self.connected:
raise ConnectionResetError("Connection lost")
if not self._paused:
return
Expand Down
10 changes: 4 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self._transport = protocol.transport

self.loop = loop
self.length = None
Expand All @@ -52,7 +51,7 @@ def __init__(

@property
def transport(self) -> Optional[asyncio.Transport]:
return self._transport
return self._protocol.transport

@property
def protocol(self) -> BaseProtocol:
Expand All @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size

if self._transport is None or self._transport.is_closing():
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
raise ConnectionResetError("Cannot write to closing transport")
self._transport.write(chunk)
transport.write(chunk)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
Expand Down Expand Up @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self.drain()

self._eof = True
self._transport = None

async def drain(self) -> None:
"""Flush the write buffer.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ async def test_connection_lost_not_paused() -> None:
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_connection_lost_paused_without_waiter() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.pause_writing()
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_drain_lost() -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,17 @@ async def test_eof_received(loop: Any) -> None:
assert proto._read_timeout_handle is not None
proto.eof_received()
assert proto._read_timeout_handle is None


async def test_connection_lost_sets_transport_to_none(loop: Any, mocker: Any) -> None:
"""Ensure that the transport is set to None when the connection is lost.

This ensures the writer knows that the connection is closed.
"""
proto = ResponseHandler(loop=loop)
proto.connection_made(mocker.Mock())
assert proto.transport is not None

proto.connection_lost(OSError())

assert proto.transport is None
17 changes: 17 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,23 @@ async def test_write_to_closing_transport(
await msg.write(b"After closing")


async def test_write_to_closed_transport(
protocol: Any, transport: Any, loop: Any
) -> None:
bdraco marked this conversation as resolved.
Show resolved Hide resolved
"""Test that writing to a closed transport raises ConnectionResetError.

The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
"""
msg = http.StreamWriter(protocol, loop)

await msg.write(b"Before transport close")
protocol.transport = None
bdraco marked this conversation as resolved.
Show resolved Hide resolved

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
await msg.write(b"After transport closed")


async def test_drain(protocol: Any, transport: Any, loop: Any) -> None:
msg = http.StreamWriter(protocol, loop)
await msg.drain()
Expand Down