From 8b919d30a2339d3b9dbc53e8958f745b0f9703ea Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 19 Dec 2025 13:30:14 -1000 Subject: [PATCH] Add decode_text parameter to WebSocket for receiving TEXT as bytes (#11764) Co-authored-by: Sam Bull Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGES/11763.feature.rst | 1 + CHANGES/11764.feature.rst | 1 + aiohttp/_websocket/models.py | 46 ++++-- aiohttp/_websocket/reader_c.pxd | 2 + aiohttp/_websocket/reader_py.py | 38 +++-- aiohttp/client.py | 134 +++++++++++++-- aiohttp/client_ws.py | 108 +++++++++++-- aiohttp/http.py | 4 + aiohttp/http_websocket.py | 6 + aiohttp/test_utils.py | 51 +++++- aiohttp/web_ws.py | 116 +++++++++++-- docs/client_reference.rst | 11 +- docs/web_reference.rst | 10 +- pyproject.toml | 1 + requirements/base.txt | 3 +- requirements/runtime-deps.in | 1 + requirements/runtime-deps.txt | 5 +- requirements/test.txt | 3 +- tests/test_client_ws_functional.py | 205 ++++++++++++++++++++++- tests/test_web_websocket_functional.py | 216 ++++++++++++++++++++++++- 20 files changed, 883 insertions(+), 79 deletions(-) create mode 100644 CHANGES/11763.feature.rst create mode 120000 CHANGES/11764.feature.rst diff --git a/CHANGES/11763.feature.rst b/CHANGES/11763.feature.rst new file mode 100644 index 00000000000..b34bfafaca8 --- /dev/null +++ b/CHANGES/11763.feature.rst @@ -0,0 +1 @@ +Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`. diff --git a/CHANGES/11764.feature.rst b/CHANGES/11764.feature.rst new file mode 120000 index 00000000000..0860becd808 --- /dev/null +++ b/CHANGES/11764.feature.rst @@ -0,0 +1 @@ +11763.feature.rst \ No newline at end of file diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 085fb460cb5..3d7e6d7d5ac 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -3,7 +3,7 @@ import json from collections.abc import Callable from enum import IntEnum -from typing import Any, Final, Literal, NamedTuple, Union, cast +from typing import Any, Final, Literal, NamedTuple, cast WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) @@ -59,6 +59,19 @@ def json( return loads(self.data) +class WSMessageTextBytes(NamedTuple): + """WebSocket TEXT message with raw bytes (no UTF-8 decoding).""" + + data: bytes + size: int + extra: str | None = None + type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT + + def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any: + """Return parsed JSON data.""" + return loads(self.data) + + class WSMessageBinary(NamedTuple): data: bytes size: int @@ -114,17 +127,26 @@ class WSMessageError(NamedTuple): type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR -WSMessage = Union[ - WSMessageContinuation, - WSMessageText, - WSMessageBinary, - WSMessagePing, - WSMessagePong, - WSMessageClose, - WSMessageClosing, - WSMessageClosed, - WSMessageError, -] +# Base message types (excluding TEXT variants) +_WSMessageBase = ( + WSMessageContinuation + | WSMessageBinary + | WSMessagePing + | WSMessagePong + | WSMessageClose + | WSMessageClosing + | WSMessageClosed + | WSMessageError +) + +# All message types +WSMessage = _WSMessageBase | WSMessageText | WSMessageTextBytes + +# Message type when decode_text=True (default) - TEXT messages have str data +WSMessageDecodeText = _WSMessageBase | WSMessageText + +# Message type when decode_text=False - TEXT messages have bytes data +WSMessageNoDecodeText = _WSMessageBase | WSMessageTextBytes WS_CLOSED_MESSAGE = WSMessageClosed() WS_CLOSING_MESSAGE = WSMessageClosing() diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 9a6fdae3e97..7e5e46f13c7 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -27,6 +27,7 @@ cdef object TUPLE_NEW cdef object WSMsgType cdef object WSMessageText +cdef object WSMessageTextBytes cdef object WSMessageBinary cdef object WSMessagePing cdef object WSMessagePong @@ -66,6 +67,7 @@ cdef class WebSocketReader: cdef WebSocketDataQueue queue cdef unsigned int _max_msg_size + cdef bint _decode_text cdef Exception _exc cdef bytearray _partial diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 5bcc2ecfb78..e0088a47af8 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -20,6 +20,7 @@ WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage: class WebSocketReader: def __init__( - self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True + self, + queue: WebSocketDataQueue, + max_msg_size: int, + compress: bool = True, + decode_text: bool = True, ) -> None: self.queue = queue self._max_msg_size = max_msg_size + self._decode_text = decode_text self._exc: Exception | None = None self._partial = bytearray() @@ -270,18 +276,24 @@ def _handle_frame( size = len(payload_merged) if opcode == OP_CODE_TEXT: - try: - text = payload_merged.decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - - # XXX: The Text and Binary messages here can be a performance - # bottleneck, so we use tuple.__new__ to improve performance. - # This is not type safe, but many tests should fail in - # test_client_ws_functional.py if this is wrong. - msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + if self._decode_text: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. + msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + else: + # Return raw bytes for TEXT messages when decode_text=False + msg = TUPLE_NEW( + WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT) + ) else: msg = TUPLE_NEW( WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) diff --git a/aiohttp/client.py b/aiohttp/client.py index fca569e3ec4..b7b5c8a7acb 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -21,7 +21,17 @@ ) from contextlib import suppress from types import TracebackType -from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final +from typing import ( + TYPE_CHECKING, + Any, + Final, + Generic, + Literal, + TypedDict, + TypeVar, + final, + overload, +) from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr from yarl import URL, Query @@ -187,6 +197,27 @@ class _RequestOptions(TypedDict, total=False): middlewares: Sequence[ClientMiddlewareType] | None +class _WSConnectOptions(TypedDict, total=False): + method: str + protocols: Collection[str] + timeout: "ClientWSTimeout | _SENTINEL" + receive_timeout: float | None + autoclose: bool + autoping: bool + heartbeat: float | None + auth: BasicAuth | None + origin: str | None + params: Query + headers: LooseHeaders | None + proxy: StrOrURL | None + proxy_auth: BasicAuth | None + ssl: SSLContext | bool | Fingerprint + server_hostname: str | None + proxy_headers: LooseHeaders | None + compress: int + max_msg_size: int + + @frozen_dataclass_decorator class ClientTimeout: total: float | None = None @@ -215,7 +246,11 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) -_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse) +_RetType_co = TypeVar( + "_RetType_co", + bound="ClientResponse | ClientWebSocketResponse[bool]", + covariant=True, +) _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -866,6 +901,35 @@ async def _connect_and_send_request( ) raise + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[True] = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[False], + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: bool = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ... + def ws_connect( self, url: StrOrURL, @@ -888,7 +952,8 @@ def ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, - ) -> "_WSRequestContextManager": + decode_text: bool = True, + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect( @@ -911,9 +976,39 @@ def ws_connect( proxy_headers=proxy_headers, compress=compress, max_msg_size=max_msg_size, + decode_text=decode_text, ) ) + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[True] = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[Literal[True]]": ... + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[False], + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[Literal[False]]": ... + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: bool = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[bool]": ... + async def _ws_connect( self, url: StrOrURL, @@ -936,7 +1031,8 @@ async def _ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, - ) -> ClientWebSocketResponse: + decode_text: bool = True, + ) -> "ClientWebSocketResponse[bool]": if timeout is not sentinel: if isinstance(timeout, ClientWSTimeout): ws_timeout = timeout @@ -1098,7 +1194,9 @@ async def _ws_connect( transport = conn.transport assert transport is not None reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) - conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + conn_proto.set_parser( + WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader + ) writer = WebSocketWriter( conn_proto, transport, @@ -1373,31 +1471,33 @@ async def __aexit__( await self.close() -class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): +class _BaseRequestContextManager( + Coroutine[Any, Any, _RetType_co], Generic[_RetType_co] +): __slots__ = ("_coro", "_resp") - def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: - self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro + def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None: + self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro - def send(self, arg: None) -> "asyncio.Future[Any]": + def send(self, arg: None) -> asyncio.Future[Any]: return self._coro.send(arg) - def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]": + def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]: return self._coro.throw(*args, **kwargs) def close(self) -> None: return self._coro.close() - def __await__(self) -> Generator[Any, None, _RetType]: + def __await__(self) -> Generator[Any, None, _RetType_co]: ret = self._coro.__await__() return ret - def __iter__(self) -> Generator[Any, None, _RetType]: + def __iter__(self) -> Generator[Any, None, _RetType_co]: return self.__await__() - async def __aenter__(self) -> _RetType: - self._resp: _RetType = await self._coro - return await self._resp.__aenter__() + async def __aenter__(self) -> _RetType_co: + self._resp: _RetType_co = await self._coro + return await self._resp.__aenter__() # type: ignore[return-value] async def __aexit__( self, @@ -1409,7 +1509,7 @@ async def __aexit__( _RequestContextManager = _BaseRequestContextManager[ClientResponse] -_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse] +_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]] class _SessionRequestContextManager: @@ -1417,7 +1517,7 @@ class _SessionRequestContextManager: def __init__( self, - coro: Coroutine["asyncio.Future[Any]", None, ClientResponse], + coro: Coroutine[asyncio.Future[Any], None, ClientResponse], session: ClientSession, ) -> None: self._coro = coro diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 36959aae0c7..f2e92149e55 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -2,8 +2,9 @@ import asyncio import sys +from collections.abc import Callable from types import TracebackType -from typing import Any, Final +from typing import Any, Final, Generic, Literal, overload from ._websocket.reader import WebSocketDataQueue from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError @@ -14,7 +15,8 @@ WS_CLOSING_MESSAGE, WebSocketError, WSCloseCode, - WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ) from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError @@ -26,10 +28,21 @@ JSONEncoder, ) +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 11): import asyncio as async_timeout + from typing import Self else: import async_timeout + from typing_extensions import Self + +# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) +# Covariant because it only affects return types, not input types +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) @frozen_dataclass_decorator @@ -43,7 +56,7 @@ class ClientWSTimeout: ) -class ClientWebSocketResponse: +class ClientWebSocketResponse(Generic[_DecodeText]): def __init__( self, reader: WebSocketDataQueue, @@ -309,7 +322,24 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo self._response.close() return True - async def receive(self, timeout: float | None = None) -> WSMessage: + @overload + async def receive( + self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None + ) -> WSMessageDecodeText: ... + + @overload + async def receive( + self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None + ) -> WSMessageNoDecodeText: ... + + @overload + async def receive( + self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + + async def receive( + self, timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: receive_timeout = timeout or self._timeout.ws_receive while True: @@ -383,7 +413,26 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: float | None = None) -> str: + @overload + async def receive_str( + self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None + ) -> str: ... + + @overload + async def receive_str( + self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None + ) -> bytes: ... + + @overload + async def receive_str( + self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None + ) -> str | bytes: ... + + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( @@ -399,25 +448,64 @@ async def receive_bytes(self, *, timeout: float | None = None) -> bytes: ) return msg.data + @overload + async def receive_json( + self: "ClientWebSocketResponse[Literal[True]]", + *, + loads: JSONDecoder = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload + async def receive_json( + self: "ClientWebSocketResponse[Literal[False]]", + *, + loads: Callable[[bytes], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload + async def receive_json( + self: "ClientWebSocketResponse[_DecodeText]", + *, + loads: JSONDecoder | Callable[[bytes], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, - loads: JSONDecoder = DEFAULT_JSON_DECODER, + loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER, timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) - return loads(data) + return loads(data) # type: ignore[arg-type] - def __aiter__(self) -> "ClientWebSocketResponse": + def __aiter__(self) -> Self: return self - async def __anext__(self) -> WSMessage: + @overload + async def __anext__( + self: "ClientWebSocketResponse[Literal[True]]", + ) -> WSMessageDecodeText: ... + + @overload + async def __anext__( + self: "ClientWebSocketResponse[Literal[False]]", + ) -> WSMessageNoDecodeText: ... + + @overload + async def __anext__( + self: "ClientWebSocketResponse[_DecodeText]", + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): raise StopAsyncIteration return msg - async def __aenter__(self) -> "ClientWebSocketResponse": + async def __aenter__(self) -> Self: return self async def __aexit__( diff --git a/aiohttp/http.py b/aiohttp/http.py index 6dad94bb11c..9d50377edf6 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -19,6 +19,8 @@ WebSocketWriter, WSCloseCode, WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ws_ext_gen, ws_ext_parse, @@ -49,6 +51,8 @@ "ws_ext_gen", "ws_ext_parse", "WSMessage", + "WSMessageDecodeText", + "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index f49d8aee287..bc6b387c6b3 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -13,10 +13,13 @@ WSMessageClosed, WSMessageClosing, WSMessageContinuation, + WSMessageDecodeText, WSMessageError, + WSMessageNoDecodeText, WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) from ._websocket.reader import WebSocketReader @@ -35,6 +38,8 @@ "WebSocketReader", "WebSocketWriter", "WSMessage", + "WSMessageDecodeText", + "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", @@ -48,6 +53,7 @@ "WSMessagePong", "WSMessageBinary", "WSMessageText", + "WSMessageTextBytes", "WSMessagePing", "WSMessageContinuation", ) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 192173b42c8..c333f6a2236 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterator from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -19,6 +19,7 @@ import aiohttp from aiohttp.client import ( + _BaseRequestContextManager, _RequestContextManager, _RequestOptions, _WSRequestContextManager, @@ -286,7 +287,7 @@ def __init__( # type: ignore[misc] self._session._retry_connection = False self._closed = False self._responses: list[ClientResponse] = [] - self._websockets: list[ClientWebSocketResponse] = [] + self._websockets: list[ClientWebSocketResponse[bool]] = [] async def start_server(self) -> None: await self._server.start_server() @@ -429,18 +430,54 @@ def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: self._request(hdrs.METH_DELETE, path, **kwargs) ) - def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... + + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ... + + def ws_connect( + self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": """Initiate websocket connection. The api corresponds to aiohttp.ClientSession.ws_connect. """ - return _WSRequestContextManager(self._ws_connect(path, **kwargs)) + return _WSRequestContextManager( + self._ws_connect(path, decode_text=decode_text, **kwargs) + ) + @overload async def _ws_connect( - self, path: StrOrURL, **kwargs: Any - ) -> ClientWebSocketResponse: - ws = await self._session.ws_connect(self.make_url(path), **kwargs) + self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any + ) -> "ClientWebSocketResponse[Literal[True]]": ... + + @overload + async def _ws_connect( + self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any + ) -> "ClientWebSocketResponse[Literal[False]]": ... + + @overload + async def _ws_connect( + self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any + ) -> "ClientWebSocketResponse[bool]": ... + + async def _ws_connect( + self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any + ) -> "ClientWebSocketResponse[bool]": + ws = await self._session.ws_connect( + self.make_url(path), decode_text=decode_text, **kwargs + ) self._websockets.append(ws) return ws diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 8eee7e3ad71..d55d3687d92 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -4,8 +4,8 @@ import hashlib import json import sys -from collections.abc import Iterable -from typing import Any, Final, Union +from collections.abc import Callable, Iterable +from typing import Any, Final, Generic, Literal, Union, overload from multidict import CIMultiDict @@ -28,7 +28,8 @@ WebSocketReader, WebSocketWriter, WSCloseCode, - WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ws_ext_gen, ws_ext_parse, @@ -41,10 +42,17 @@ from .web_request import BaseRequest from .web_response import StreamResponse +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 11): import asyncio as async_timeout + from typing import Self else: import async_timeout + from typing_extensions import Self __all__ = ( "WebSocketResponse", @@ -54,6 +62,9 @@ THRESHOLD_CONNLOST_ACCESS: Final[int] = 5 +# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) + @frozen_dataclass_decorator class WebSocketReady: @@ -64,7 +75,7 @@ def __bool__(self) -> bool: return self.ok -class WebSocketResponse(StreamResponse): +class WebSocketResponse(StreamResponse, Generic[_DecodeText]): _length_check: bool = False _ws_protocol: str | None = None @@ -95,6 +106,7 @@ def __init__( compress: bool = True, max_msg_size: int = 4 * 1024 * 1024, writer_limit: int = DEFAULT_LIMIT, + decode_text: bool = True, ) -> None: super().__init__(status=101) self._protocols = protocols @@ -108,6 +120,7 @@ def __init__( self._compress: bool | int = compress self._max_msg_size = max_msg_size self._writer_limit = writer_limit + self._decode_text = decode_text def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() @@ -341,7 +354,10 @@ def _post_start( self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop) request.protocol.set_parser( WebSocketReader( - self._reader, self._max_msg_size, compress=bool(self._compress) + self._reader, + self._max_msg_size, + compress=bool(self._compress), + decode_text=self._decode_text, ) ) # disable HTTP keepalive for WebSocket @@ -514,7 +530,24 @@ def _close_transport(self) -> None: if self._req is not None and self._req.transport is not None: self._req.transport.close() - async def receive(self, timeout: float | None = None) -> WSMessage: + @overload + async def receive( + self: "WebSocketResponse[Literal[True]]", timeout: float | None = None + ) -> WSMessageDecodeText: ... + + @overload + async def receive( + self: "WebSocketResponse[Literal[False]]", timeout: float | None = None + ) -> WSMessageNoDecodeText: ... + + @overload + async def receive( + self: "WebSocketResponse[_DecodeText]", timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + + async def receive( + self, timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: if self._reader is None: raise RuntimeError("Call .prepare() first") @@ -588,7 +621,26 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: float | None = None) -> str: + @overload + async def receive_str( + self: "WebSocketResponse[Literal[True]]", *, timeout: float | None = None + ) -> str: ... + + @overload + async def receive_str( + self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None + ) -> bytes: ... + + @overload + async def receive_str( + self: "WebSocketResponse[_DecodeText]", *, timeout: float | None = None + ) -> str | bytes: ... + + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( @@ -604,21 +656,63 @@ async def receive_bytes(self, *, timeout: float | None = None) -> bytes: ) return msg.data + @overload async def receive_json( - self, *, loads: JSONDecoder = json.loads, timeout: float | None = None + self: "WebSocketResponse[Literal[True]]", + *, + loads: JSONDecoder = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload + async def receive_json( + self: "WebSocketResponse[Literal[False]]", + *, + loads: Callable[[bytes], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload + async def receive_json( + self: "WebSocketResponse[_DecodeText]", + *, + loads: JSONDecoder | Callable[[bytes], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + + async def receive_json( + self, + *, + loads: JSONDecoder | Callable[[bytes], Any] = json.loads, + timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) - return loads(data) + return loads(data) # type: ignore[arg-type] async def write( self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] ) -> None: raise RuntimeError("Cannot call .write() for websocket") - def __aiter__(self) -> "WebSocketResponse": + def __aiter__(self) -> Self: return self - async def __anext__(self) -> WSMessage: + @overload + async def __anext__( + self: "WebSocketResponse[Literal[True]]", + ) -> WSMessageDecodeText: ... + + @overload + async def __anext__( + self: "WebSocketResponse[Literal[False]]", + ) -> WSMessageNoDecodeText: ... + + @overload + async def __anext__( + self: "WebSocketResponse[_DecodeText]", + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): raise StopAsyncIteration diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 50d158b5c2a..ab52fbfa5fb 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -722,7 +722,8 @@ The client session supports the context manager protocol for self closing. proxy=None, proxy_auth=None, ssl=True, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ - compress=0, max_msg_size=4194304) + compress=0, max_msg_size=4194304, \ + decode_text=True) :async: Create a websocket connection. Returns a @@ -851,6 +852,14 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.5 + :param bool decode_text: If ``True`` (default), TEXT messages are + decoded to strings. If ``False``, TEXT messages + are returned as raw bytes, which can improve + performance when using JSON parsers like + ``orjson`` that accept bytes directly. + + .. versionadded:: 3.14 + .. method:: close() :async: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 048d798f8c1..01b237f1b0a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -939,7 +939,7 @@ and :ref:`aiohttp-web-signals` handlers:: .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ autoclose=True, autoping=True, heartbeat=None, \ protocols=(), compress=True, max_msg_size=4194304, \ - writer_limit=65536) + writer_limit=65536, decode_text=True) Class for handling server-side websockets, inherited from :class:`StreamResponse`. @@ -1002,6 +1002,14 @@ and :ref:`aiohttp-web-signals` handlers:: .. versionadded:: 3.11 + :param bool decode_text: If ``True`` (default), TEXT messages are + decoded to strings. If ``False``, TEXT messages + are returned as raw bytes, which can improve + performance when using JSON parsers like + ``orjson`` that accept bytes directly. + + .. versionadded:: 3.14 + The class supports ``async for`` statement for iterating over incoming messages:: diff --git a/pyproject.toml b/pyproject.toml index 8b707ddc4cb..0cfa7a3221b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "frozenlist >= 1.1.1", "multidict >=4.5, < 7.0", "propcache >= 0.2.0", + "typing_extensions >= 4.4 ; python_version < '3.13'", "yarl >= 1.17.0, < 2.0", ] dynamic = [ diff --git a/requirements/base.txt b/requirements/base.txt index aded022bbca..eb87dd6da6e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -40,8 +40,9 @@ pycares==4.11.0 # via aiodns pycparser==2.23 # via cffi -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # multidict uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 0be3bb7f98f..16515e7551a 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -10,4 +10,5 @@ brotlicffi; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 +typing_extensions >= 4.4 ; python_version < '3.13' yarl >= 1.17.0, < 2.0 diff --git a/requirements/runtime-deps.txt b/requirements/runtime-deps.txt index f45d006b614..e02c165910f 100644 --- a/requirements/runtime-deps.txt +++ b/requirements/runtime-deps.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/runtime-deps.txt --strip-extras requirements/runtime-deps.in @@ -36,8 +36,9 @@ pycares==4.11.0 # via aiodns pycparser==2.23 # via cffi -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # multidict yarl==1.22.0 diff --git a/requirements/test.txt b/requirements/test.txt index 42f6caee201..7f862c593ac 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -131,8 +131,9 @@ tomli==2.3.0 # pytest trustme==1.2.1 ; platform_machine != "i686" # via -r requirements/test-common.in -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 0bc05f300d4..3cefbb26d3d 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import NoReturn +from typing import Literal, NoReturn from unittest import mock import pytest @@ -1277,7 +1277,7 @@ async def handler(request: web.Request) -> NoReturn: app = web.Application() app.router.add_route("GET", "/", handler) - sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse]] = ( + sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse[bool]]] = ( loop.create_future() ) client = await aiohttp_client(app) @@ -1305,3 +1305,204 @@ async def websocket_task() -> None: # Cleanup properly websocket._response = mock.Mock() await websocket.close() + + +async def test_receive_text_as_bytes_client_side(aiohttp_client: AiohttpClient) -> None: + """Test client receiving TEXT messages as raw bytes with decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Connect with decode_text=False + resp = await client.ws_connect("/", decode_text=False) + await resp.send_str("ask") + + # Receive TEXT message as bytes + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"ask/answer" + + await resp.close() + + +async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: + """Test server receiving TEXT messages as raw bytes with decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) + await ws.prepare(request) + + # Receive TEXT message as bytes + msg = await ws.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"test message" + + # Send response + await ws.send_bytes(msg.data + b"/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + resp = await client.ws_connect("/") + await resp.send_str("test message") + + msg = await resp.receive() + assert msg.type is WSMsgType.BINARY + assert msg.data == b"test message/reply" + + await resp.close() + + +async def test_receive_text_as_bytes_json_parsing( + aiohttp_client: AiohttpClient, +) -> None: + """Test using orjson or similar parsers with raw bytes from TEXT messages.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + data = json.loads(msg) + await ws.send_str(json.dumps({"response": data["value"] * 2})) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Connect with decode_text=False to get raw bytes + resp = await client.ws_connect("/", decode_text=False) + await resp.send_str(json.dumps({"value": 42})) + + # Receive TEXT message as bytes + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + + # Parse JSON using msg.json() method (covers WSMessageTextBytes.json()) + data = msg.json() + assert data == {"response": 84} + + await resp.close() + + +async def test_decode_text_default_true(aiohttp_client: AiohttpClient) -> None: + """Test that decode_text defaults to True for backward compatibility.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Default behavior (decode_text=True) + resp = await client.ws_connect("/") + await resp.send_str("test") + + # Should receive TEXT message as string + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test/reply" + + await resp.close() + + +async def test_receive_str_returns_bytes_with_decode_text_false( + aiohttp_client: AiohttpClient, +) -> None: + """Test that receive_str() returns bytes when decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str("hello world") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/", decode_text=False) as ws: + # receive_str() should return bytes when decode_text=False + data = await ws.receive_str() + assert isinstance(data, bytes) + assert data == b"hello world" + + +async def test_receive_str_returns_str_with_decode_text_true( + aiohttp_client: AiohttpClient, +) -> None: + """Test that receive_str() returns str when decode_text=True (default).""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str("hello world") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + # receive_str() should return str when decode_text=True (default) + data = await ws.receive_str() + assert isinstance(data, str) + assert data == "hello world" + + +async def test_receive_json_with_orjson_style_loads( + aiohttp_client: AiohttpClient, +) -> None: + """Test receive_json() with orjson-style loads that accepts bytes.""" + + def orjson_style_loads(data: bytes) -> dict[str, int]: + """Mock orjson.loads that accepts bytes.""" + assert isinstance(data, bytes) + result: dict[str, int] = json.loads(data) + return result + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str('{"value": 42}') + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/", decode_text=False) as ws: + # receive_json() with orjson-style loads should work with bytes + data = await ws.receive_json(loads=orjson_style_loads) + assert data == {"value": 42} diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index afa76e2d742..0e41faa21f2 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -2,9 +2,10 @@ import asyncio import contextlib +import json import sys import weakref -from typing import NoReturn +from typing import Literal, NoReturn from unittest import mock import pytest @@ -1445,3 +1446,216 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert msg.type is WSMsgType.TEXT assert msg.data == "test" await ws.close() + + +async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: + """Test server receiving TEXT messages as raw bytes with decode_text=False.""" + + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) + await ws.prepare(request) + + # Receive TEXT message as bytes + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"test message" + + # Send response + await ws.send_bytes(msg.data + b"/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("test message") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"test message/reply" + + await ws.close() + + +async def test_receive_text_as_bytes_server_iteration( + aiohttp_client: AiohttpClient, +) -> None: + """Test server iterating over WebSocket with decode_text=False.""" + + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) + await ws.prepare(request) + + async for msg in ws: + if msg.type is aiohttp.WSMsgType.TEXT: + # msg.data should be bytes + assert isinstance(msg.data, bytes) + # Echo back + await ws.send_bytes(msg.data) + else: + assert msg.type is aiohttp.WSMsgType.BINARY + assert isinstance(msg.data, bytes) + await ws.send_bytes(msg.data) + + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + # Send TEXT message + await ws.send_str("hello") + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"hello" + + # Send BINARY message + await ws.send_bytes(b"world") + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"world" + + await ws.close() + + +async def test_server_decode_text_default_true(aiohttp_client: AiohttpClient) -> None: + """Test that server decode_text defaults to True for backward compatibility.""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + # No decode_text parameter - should default to True + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test" + + await ws.send_str(msg.data + "/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test/reply" + + await ws.close() + + +async def test_server_receive_str_returns_bytes_with_decode_text_false( + aiohttp_client: AiohttpClient, +) -> None: + """Test that server receive_str() returns bytes when decode_text=False.""" + + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) + await ws.prepare(request) + + # receive_str() should return bytes when decode_text=False + data = await ws.receive_str() + assert isinstance(data, bytes) + assert data == b"hello server" + + await ws.send_str("got bytes") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("hello server") + msg = await ws.receive() + assert msg.data == "got bytes" + + +async def test_server_receive_str_returns_str_with_decode_text_true( + aiohttp_client: AiohttpClient, +) -> None: + """Test that server receive_str() returns str when decode_text=True (default).""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() # decode_text=True by default + await ws.prepare(request) + + # receive_str() should return str when decode_text=True + data = await ws.receive_str() + assert isinstance(data, str) + assert data == "hello server" + + await ws.send_str("got string") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("hello server") + msg = await ws.receive() + assert msg.data == "got string" + + +async def test_server_receive_json_with_orjson_style_loads( + aiohttp_client: AiohttpClient, +) -> None: + """Test server receive_json() with orjson-style loads that accepts bytes.""" + + def orjson_style_loads(data: bytes) -> dict[str, str]: + """Mock orjson.loads that accepts bytes.""" + assert isinstance(data, bytes) + result: dict[str, str] = json.loads(data) + return result + + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) + await ws.prepare(request) + + # receive_json() with orjson-style loads should work with bytes + data = await ws.receive_json(loads=orjson_style_loads) + assert data == {"test": "value"} + + await ws.send_str("success") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + ws = await client.ws_connect("/") + await ws.send_str('{"test": "value"}') + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert msg.data == "success" + await ws.close()