Skip to content

Commit

Permalink
Let legacy ws clients know about invalid data frames (#3633)
Browse files Browse the repository at this point in the history
* Let ws clients know about invalid data frames

* Use the same close reason across all frameworks

* Add a release file

* Make it a minor release
  • Loading branch information
DoctorJohn committed Sep 19, 2024
1 parent 2941146 commit 63528a5
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 39 deletions.
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Release type: minor

Starting with this release, clients using the legacy graphql-ws subprotocol will receive an error when they try to send binary data frames.
Before, binary data frames were silently ignored.

While vaguely defined in the protocol, the legacy graphql-ws subprotocol is generally understood to only support text data frames.
4 changes: 4 additions & 0 deletions strawberry/aiohttp/handlers/graphql_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ async def handle_request(self) -> Any:
if ws_message.type == http.WSMsgType.TEXT:
message: OperationMessage = ws_message.json()
await self.handle_message(message)
else:
await self.close(
code=1002, reason="WebSocket message type must be text"
)
finally:
if self.keep_alive_task:
self.keep_alive_task.cancel()
Expand Down
5 changes: 3 additions & 2 deletions strawberry/asgi/handlers/graphql_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ async def handle_request(self) -> Any:
try:
message = await self._ws.receive_json()
except KeyError: # noqa: PERF203
# Ignore non-text messages
continue
await self.close(
code=1002, reason="WebSocket message type must be text"
)
else:
await self.handle_message(message)
except WebSocketDisconnect: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion strawberry/channels/handlers/graphql_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def handle_invalid_message(self, error_message: str) -> None:
# This is not part of the BaseGraphQLWSHandler's interface, but the
# channels integration is a high level wrapper that forwards this to
# both us and the BaseGraphQLTransportWSHandler.
pass
await self.close(code=1002, reason=error_message)


__all__ = ["GraphQLWSHandler"]
5 changes: 3 additions & 2 deletions strawberry/channels/handlers/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ async def receive(self, *args: str, **kwargs: Any) -> None:
# Overriding this so that we can pass the errors to handle_invalid_message
try:
await super().receive(*args, **kwargs)
except ValueError as e:
await self._handler.handle_invalid_message(str(e))
except ValueError:
reason = "WebSocket message type must be text"
await self._handler.handle_invalid_message(reason)

async def receive_json(self, content: Any, **kwargs: Any) -> None:
await self._handler.handle_message(content)
Expand Down
5 changes: 3 additions & 2 deletions strawberry/litestar/handlers/graphql_ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ async def handle_request(self) -> Any:
try:
message = await self._ws.receive_json()
except (SerializationException, ValueError): # noqa: PERF203
# Ignore non-text messages
continue
await self.close(
code=1002, reason="WebSocket message type must be text"
)
else:
await self.handle_message(message)
except WebSocketDisconnect: # pragma: no cover
Expand Down
32 changes: 27 additions & 5 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,35 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):

await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode())

data = await ws.receive(timeout=2)
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
ws.assert_reason("WebSocket message type must be text")


async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_json(ConnectionInitMessage().as_dict())

response = await ws.receive_json()
assert response == ConnectionAckMessage().as_dict()

await ws.send_bytes(
json.dumps(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { debug { isConnectionInitTimeoutTaskDone } }"
),
).as_dict()
).encode()
)

await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
if ws.name() == "channels":
ws.assert_reason("No text section for incoming WebSocket frame!")
else:
ws.assert_reason("WebSocket message type must be text")
ws.assert_reason("WebSocket message type must be text")


async def test_connection_init_timeout(
Expand Down
54 changes: 27 additions & 27 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import json
from typing import TYPE_CHECKING, AsyncGenerator
from unittest import mock

Expand Down Expand Up @@ -280,42 +281,41 @@ async def test_subscription_syntax_error(ws: WebSocketClient):
}


async def test_non_text_ws_messages_are_ignored(ws_raw: WebSocketClient):
async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
ws = ws_raw
await ws.send_bytes(b"foo")
await ws.send_json({"type": GQL_CONNECTION_INIT})

await ws.send_bytes(b"bar")
await ws.send_json(
{
"type": GQL_START,
"id": "demo",
"payload": {
"query": 'subscription { echo(message: "Hi") }',
},
}
)
await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode())

response = await ws.receive_json()
assert response["type"] == GQL_CONNECTION_ACK
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 1002
ws.assert_reason("WebSocket message type must be text")

response = await ws.receive_json()
assert response["type"] == GQL_DATA
assert response["id"] == "demo"
assert response["payload"]["data"] == {"echo": "Hi"}

await ws.send_bytes(b"gaz")
await ws.send_json({"type": GQL_STOP, "id": "demo"})
async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_json({"type": GQL_CONNECTION_INIT})

response = await ws.receive_json()
assert response["type"] == GQL_COMPLETE
assert response["id"] == "demo"
assert response["type"] == GQL_CONNECTION_ACK

await ws.send_bytes(b"wat")
await ws.send_json({"type": GQL_CONNECTION_TERMINATE})
await ws.send_bytes(
json.dumps(
{
"type": GQL_START,
"id": "demo",
"payload": {
"query": 'subscription { echo(message: "Hi") }',
},
}
).encode()
)

# make sure the WebSocket is disconnected now
await ws.receive(timeout=2) # receive close
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 1002
ws.assert_reason("WebSocket message type must be text")


async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient):
Expand Down

0 comments on commit 63528a5

Please sign in to comment.