Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/11763.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/11764.feature.rst
46 changes: 34 additions & 12 deletions aiohttp/_websocket/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from collections.abc import Callable
from enum import IntEnum
from typing import Any, Final, Literal, NamedTuple, Union, cast
from typing import Any, Final, Literal, NamedTuple, cast

WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])

Expand Down Expand Up @@ -59,6 +59,19 @@ def json(
return loads(self.data)


class WSMessageTextBytes(NamedTuple):
    """TEXT message whose payload is kept as bytes.

    Unlike the ``str``-carrying TEXT message, ``data`` here has not been
    UTF-8 decoded.
    """

    data: bytes
    size: int
    extra: str | None = None
    type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

    def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any:
        """Deserialize ``data`` with *loads* and return the result.

        *loads* is handed the bytes payload unchanged; it defaults to
        :func:`json.loads`.
        """
        payload = self.data
        return loads(payload)


class WSMessageBinary(NamedTuple):
data: bytes
size: int
Expand Down Expand Up @@ -114,17 +127,26 @@ class WSMessageError(NamedTuple):
type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR


WSMessage = Union[
WSMessageContinuation,
WSMessageText,
WSMessageBinary,
WSMessagePing,
WSMessagePong,
WSMessageClose,
WSMessageClosing,
WSMessageClosed,
WSMessageError,
]
# Message classes common to both receive modes: every message type except
# the two TEXT variants (WSMessageText and WSMessageTextBytes).
_WSMessageBase = (
    WSMessageContinuation
    | WSMessageBinary
    | WSMessagePing
    | WSMessagePong
    | WSMessageClose
    | WSMessageClosing
    | WSMessageClosed
    | WSMessageError
)

# Every message type, covering both TEXT variants.
WSMessage = _WSMessageBase | WSMessageText | WSMessageTextBytes

# Message type when decode_text=True (the default): TEXT messages carry
# str data.
WSMessageDecodeText = _WSMessageBase | WSMessageText

# Message type when decode_text=False: TEXT messages carry bytes data.
WSMessageNoDecodeText = _WSMessageBase | WSMessageTextBytes

WS_CLOSED_MESSAGE = WSMessageClosed()
WS_CLOSING_MESSAGE = WSMessageClosing()
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cdef object TUPLE_NEW
cdef object WSMsgType

cdef object WSMessageText
cdef object WSMessageTextBytes
cdef object WSMessageBinary
cdef object WSMessagePing
cdef object WSMessagePong
Expand Down Expand Up @@ -66,6 +67,7 @@ cdef class WebSocketReader:

cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef bint _decode_text

cdef Exception _exc
cdef bytearray _partial
Expand Down
38 changes: 25 additions & 13 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
WSMessagePing,
WSMessagePong,
WSMessageText,
WSMessageTextBytes,
WSMsgType,
)

Expand Down Expand Up @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage:

class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
self,
queue: WebSocketDataQueue,
max_msg_size: int,
compress: bool = True,
decode_text: bool = True,
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._decode_text = decode_text

self._exc: Exception | None = None
self._partial = bytearray()
Expand Down Expand Up @@ -270,18 +276,24 @@ def _handle_frame(

size = len(payload_merged)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
if self._decode_text:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
else:
# Return raw bytes for TEXT messages when decode_text=False
msg = TUPLE_NEW(
WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
)
else:
msg = TUPLE_NEW(
WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
Expand Down
134 changes: 117 additions & 17 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
)
from contextlib import suppress
from types import TracebackType
from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final
from typing import (
TYPE_CHECKING,
Any,
Final,
Generic,
Literal,
TypedDict,
TypeVar,
final,
overload,
)

from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL, Query
Expand Down Expand Up @@ -187,6 +197,27 @@ class _RequestOptions(TypedDict, total=False):
middlewares: Sequence[ClientMiddlewareType] | None


class _WSConnectOptions(TypedDict, total=False):
    """Keyword options typed for the ``ws_connect`` overloads.

    Used as ``**kwargs: Unpack[_WSConnectOptions]`` in the typing-only
    ``ws_connect`` and ``_ws_connect`` overloads.  ``total=False`` makes
    every key optional.  ``decode_text`` is not listed here because the
    overloads declare it as an explicit keyword parameter.
    """

    method: str
    protocols: Collection[str]
    # NOTE(review): quoted annotation -- presumably a forward reference to
    # names not resolvable at this point; confirm.
    timeout: "ClientWSTimeout | _SENTINEL"
    receive_timeout: float | None
    autoclose: bool
    autoping: bool
    heartbeat: float | None
    auth: BasicAuth | None
    origin: str | None
    params: Query
    headers: LooseHeaders | None
    proxy: StrOrURL | None
    proxy_auth: BasicAuth | None
    ssl: SSLContext | bool | Fingerprint
    server_hostname: str | None
    proxy_headers: LooseHeaders | None
    compress: int
    max_msg_size: int


@frozen_dataclass_decorator
class ClientTimeout:
total: float | None = None
Expand Down Expand Up @@ -215,7 +246,11 @@ class ClientTimeout:
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})

_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
_RetType_co = TypeVar(
"_RetType_co",
bound="ClientResponse | ClientWebSocketResponse[bool]",
covariant=True,
)
_CharsetResolver = Callable[[ClientResponse, bytes], str]


Expand Down Expand Up @@ -866,6 +901,35 @@ async def _connect_and_send_request(
)
raise

if sys.version_info >= (3, 11) and TYPE_CHECKING:
    # Typing-only overloads: the TYPE_CHECKING guard means they are never
    # evaluated at runtime.  The static type of ``decode_text`` selects the
    # type parameter of the returned ClientWebSocketResponse.
    # NOTE(review): the 3.11 gate presumably tracks the availability of
    # ``Unpack`` -- confirm against the file's imports.

    # decode_text omitted or literally True -> ...[Literal[True]].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[True] = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

    # decode_text literally False (no default, so it must be passed
    # explicitly) -> ...[Literal[False]].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[False],
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

    # Fallback for a decode_text value only known to be bool -> ...[bool].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: bool = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...

def ws_connect(
self,
url: StrOrURL,
Expand All @@ -888,7 +952,8 @@ def ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> "_WSRequestContextManager":
decode_text: bool = True,
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
"""Initiate websocket connection."""
return _WSRequestContextManager(
self._ws_connect(
Expand All @@ -911,9 +976,39 @@ def ws_connect(
proxy_headers=proxy_headers,
compress=compress,
max_msg_size=max_msg_size,
decode_text=decode_text,
)
)

if sys.version_info >= (3, 11) and TYPE_CHECKING:
    # Typing-only overloads for the coroutine: the TYPE_CHECKING guard means
    # they are never evaluated at runtime.  The static type of
    # ``decode_text`` selects the type parameter of the returned
    # ClientWebSocketResponse.

    # decode_text omitted or literally True -> ClientWebSocketResponse[Literal[True]].
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[True] = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    # decode_text literally False (no default, so it must be passed
    # explicitly) -> ClientWebSocketResponse[Literal[False]].
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[False],
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    # Fallback for a decode_text value only known to be bool.
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: bool = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[bool]": ...

async def _ws_connect(
self,
url: StrOrURL,
Expand All @@ -936,7 +1031,8 @@ async def _ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> ClientWebSocketResponse:
decode_text: bool = True,
) -> "ClientWebSocketResponse[bool]":
if timeout is not sentinel:
if isinstance(timeout, ClientWSTimeout):
ws_timeout = timeout
Expand Down Expand Up @@ -1098,7 +1194,9 @@ async def _ws_connect(
transport = conn.transport
assert transport is not None
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
)
writer = WebSocketWriter(
conn_proto,
transport,
Expand Down Expand Up @@ -1373,31 +1471,33 @@ async def __aexit__(
await self.close()


class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]):
class _BaseRequestContextManager(
Coroutine[Any, Any, _RetType_co], Generic[_RetType_co]
):
__slots__ = ("_coro", "_resp")

def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro
def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None:
self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro

def send(self, arg: None) -> "asyncio.Future[Any]":
def send(self, arg: None) -> asyncio.Future[Any]:
return self._coro.send(arg)

def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]":
def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]:
return self._coro.throw(*args, **kwargs)

def close(self) -> None:
return self._coro.close()

def __await__(self) -> Generator[Any, None, _RetType]:
def __await__(self) -> Generator[Any, None, _RetType_co]:
ret = self._coro.__await__()
return ret

def __iter__(self) -> Generator[Any, None, _RetType]:
def __iter__(self) -> Generator[Any, None, _RetType_co]:
return self.__await__()

async def __aenter__(self) -> _RetType:
self._resp: _RetType = await self._coro
return await self._resp.__aenter__()
async def __aenter__(self) -> _RetType_co:
self._resp: _RetType_co = await self._coro
return await self._resp.__aenter__() # type: ignore[return-value]

async def __aexit__(
self,
Expand All @@ -1409,15 +1509,15 @@ async def __aexit__(


_RequestContextManager = _BaseRequestContextManager[ClientResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]]


class _SessionRequestContextManager:
__slots__ = ("_coro", "_resp", "_session")

def __init__(
self,
coro: Coroutine["asyncio.Future[Any]", None, ClientResponse],
coro: Coroutine[asyncio.Future[Any], None, ClientResponse],
session: ClientSession,
) -> None:
self._coro = coro
Expand Down
Loading
Loading