diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..4e02948cbe --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,6 @@ +Release type: minor + +Starting with this release, clients using the legacy graphql-ws subprotocol will receive an error when they try to send binary data frames. +Before, binary data frames were silently ignored. + +While vaguely defined in the protocol, the legacy graphql-ws subprotocol is generally understood to only support text data frames. diff --git a/strawberry/aiohttp/handlers/graphql_ws_handler.py b/strawberry/aiohttp/handlers/graphql_ws_handler.py index a8a80e481c..677dd34884 100644 --- a/strawberry/aiohttp/handlers/graphql_ws_handler.py +++ b/strawberry/aiohttp/handlers/graphql_ws_handler.py @@ -50,6 +50,10 @@ async def handle_request(self) -> Any: if ws_message.type == http.WSMsgType.TEXT: message: OperationMessage = ws_message.json() await self.handle_message(message) + else: + await self.close( + code=1002, reason="WebSocket message type must be text" + ) finally: if self.keep_alive_task: self.keep_alive_task.cancel() diff --git a/strawberry/asgi/handlers/graphql_ws_handler.py b/strawberry/asgi/handlers/graphql_ws_handler.py index 5ef966f63b..00a314bbd0 100644 --- a/strawberry/asgi/handlers/graphql_ws_handler.py +++ b/strawberry/asgi/handlers/graphql_ws_handler.py @@ -51,8 +51,9 @@ async def handle_request(self) -> Any: try: message = await self._ws.receive_json() except KeyError: # noqa: PERF203 - # Ignore non-text messages - continue + await self.close( + code=1002, reason="WebSocket message type must be text" + ) else: await self.handle_message(message) except WebSocketDisconnect: # pragma: no cover diff --git a/strawberry/channels/handlers/graphql_ws_handler.py b/strawberry/channels/handlers/graphql_ws_handler.py index 41ed4dc6ea..6d967a1d15 100644 --- a/strawberry/channels/handlers/graphql_ws_handler.py +++ b/strawberry/channels/handlers/graphql_ws_handler.py @@ -66,7 +66,7 @@ async def handle_invalid_message(self, error_message: str) -> None: # This is not part of the BaseGraphQLWSHandler's interface, but the # channels integration is a high level wrapper that forwards this to # both us and the BaseGraphQLTransportWSHandler. - pass + await self.close(code=1002, reason=error_message) __all__ = ["GraphQLWSHandler"] diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 7f7929b3f3..2991059afd 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -106,8 +106,9 @@ async def receive(self, *args: str, **kwargs: Any) -> None: # Overriding this so that we can pass the errors to handle_invalid_message try: await super().receive(*args, **kwargs) - except ValueError as e: - await self._handler.handle_invalid_message(str(e)) + except ValueError: + reason = "WebSocket message type must be text" + await self._handler.handle_invalid_message(reason) async def receive_json(self, content: Any, **kwargs: Any) -> None: await self._handler.handle_message(content) diff --git a/strawberry/litestar/handlers/graphql_ws_handler.py b/strawberry/litestar/handlers/graphql_ws_handler.py index d91ec4f8c3..ada421922f 100644 --- a/strawberry/litestar/handlers/graphql_ws_handler.py +++ b/strawberry/litestar/handlers/graphql_ws_handler.py @@ -46,8 +46,9 @@ async def handle_request(self) -> Any: try: message = await self._ws.receive_json() except (SerializationException, ValueError): # noqa: PERF203 - # Ignore non-text messages - continue + await self.close( + code=1002, reason="WebSocket message type must be text" + ) else: await self.handle_message(message) except WebSocketDisconnect: # pragma: no cover diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 8062f25a5c..02f8366852 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -117,13 +117,35 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode()) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4400 + ws.assert_reason("WebSocket message type must be text") + + +async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_json(ConnectionInitMessage().as_dict()) + + response = await ws.receive_json() + assert response == ConnectionAckMessage().as_dict() + + await ws.send_bytes( + json.dumps( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { debug { isConnectionInitTimeoutTaskDone } }" + ), + ).as_dict() + ).encode() + ) + + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - if ws.name() == "channels": - ws.assert_reason("No text section for incoming WebSocket frame!") - else: - ws.assert_reason("WebSocket message type must be text") + ws.assert_reason("WebSocket message type must be text") async def test_connection_init_timeout( diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index b16a8d12db..6752eaa7f8 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json from typing import TYPE_CHECKING, AsyncGenerator from unittest import mock @@ -280,42 +281,41 @@ async def test_subscription_syntax_error(ws: WebSocketClient): } -async def test_non_text_ws_messages_are_ignored(ws_raw: WebSocketClient): +async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_bytes(b"foo") - await ws.send_json({"type": GQL_CONNECTION_INIT}) - await ws.send_bytes(b"bar") - await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } - ) + await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode()) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 1002 + ws.assert_reason("WebSocket message type must be text") - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi"} - await ws.send_bytes(b"gaz") - await ws.send_json({"type": GQL_STOP, "id": "demo"}) +async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_json({"type": GQL_CONNECTION_INIT}) + response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + assert response["type"] == GQL_CONNECTION_ACK - await ws.send_bytes(b"wat") - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.send_bytes( + json.dumps( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ).encode() + ) - # make sure the WebSocket is disconnected now - await ws.receive(timeout=2) # receive close + await ws.receive(timeout=2) assert ws.closed + assert ws.close_code == 1002 + ws.assert_reason("WebSocket message type must be text") async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient):