Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
11f2435
Add option to avoid decoding WebSocket TEXT
bdraco Nov 17, 2025
38b3c34
touch ups
bdraco Nov 17, 2025
104666b
decode type
bdraco Nov 17, 2025
348529d
coverage for bytes loads
bdraco Nov 17, 2025
286ca9b
Merge branch 'master' into no_decode_websocket_option
bdraco Dec 12, 2025
421896a
unpack
bdraco Dec 12, 2025
e9c5a7a
fallback overloads
bdraco Dec 12, 2025
10f7c52
fallback overloads
bdraco Dec 12, 2025
baab16e
fallback overloads
bdraco Dec 12, 2025
3c3ad78
fallback overloads
bdraco Dec 12, 2025
ca14d5f
fallback overloads
bdraco Dec 12, 2025
2e6d996
fallback overloads
bdraco Dec 12, 2025
550ca1a
fallback overloads
bdraco Dec 12, 2025
bc4f6a0
fallback overloads
bdraco Dec 12, 2025
5620648
fallback overloads
bdraco Dec 12, 2025
e025546
fallback overloads
bdraco Dec 12, 2025
5ef6d87
narrow receive
bdraco Dec 12, 2025
500c239
narrow receive
bdraco Dec 12, 2025
2b85337
narrow
bdraco Dec 12, 2025
ef7b0b2
default only works on py3.13+
bdraco Dec 12, 2025
51b6154
no fallback;
bdraco Dec 12, 2025
5328513
no fallback;
bdraco Dec 12, 2025
976d816
try another way for fallback
bdraco Dec 12, 2025
f6d9d1e
try another way for fallback
bdraco Dec 12, 2025
39f2a47
need to have default or everything has to be updated
bdraco Dec 12, 2025
df9d892
update tests as well
bdraco Dec 12, 2025
4ecc82c
infer from false
bdraco Dec 12, 2025
6bd3b60
just set them
bdraco Dec 12, 2025
cd9a044
just set them
bdraco Dec 12, 2025
c34ba73
just set them
bdraco Dec 12, 2025
fc32471
cleanup
bdraco Dec 12, 2025
743132d
changelog
bdraco Dec 12, 2025
633fc46
reduce
bdraco Dec 12, 2025
a49441f
reduce
bdraco Dec 12, 2025
55852de
fix changelog
bdraco Dec 12, 2025
b6493ee
changelog
bdraco Dec 12, 2025
f8015af
Update pyproject.toml
bdraco Dec 12, 2025
69c13bf
Update aiohttp/_websocket/models.py
bdraco Dec 12, 2025
79a9036
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2025
ec38532
tweaks
bdraco Dec 12, 2025
73e297f
Merge remote-tracking branch 'upstream/no_decode_websocket_option' in…
bdraco Dec 12, 2025
080880d
newer syntax
bdraco Dec 12, 2025
d639aaf
narrow to bytes
bdraco Dec 12, 2025
c85cce7
docs
bdraco Dec 14, 2025
637d829
Merge branch 'master' into no_decode_websocket_option
bdraco Dec 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/11763.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/11764.feature.rst
46 changes: 34 additions & 12 deletions aiohttp/_websocket/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from collections.abc import Callable
from enum import IntEnum
from typing import Any, Final, Literal, NamedTuple, Union, cast
from typing import Any, Final, Literal, NamedTuple, cast

WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])

Expand Down Expand Up @@ -59,6 +59,19 @@ def json(
return loads(self.data)


class WSMessageTextBytes(NamedTuple):
    """TEXT message whose payload is kept as undecoded bytes.

    Produced instead of the ``str``-carrying TEXT message when the reader
    is created with ``decode_text=False``; no UTF-8 decoding is applied to
    ``data``.
    """

    # Raw payload of the TEXT message, exactly as received.
    data: bytes
    # Payload length as supplied by whoever builds the message.
    size: int
    # Optional extra string; ``None`` unless provided.
    extra: str | None = None
    # Discriminator for narrowing; always ``WSMsgType.TEXT``.
    type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

    def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any:
        """Parse the raw payload as JSON.

        :param loads: callable that receives ``data`` (bytes) and returns the
            parsed object; defaults to :func:`json.loads`.
        :return: whatever ``loads`` returns for the payload.
        """
        raw_payload = self.data
        return loads(raw_payload)


class WSMessageBinary(NamedTuple):
data: bytes
size: int
Expand Down Expand Up @@ -114,17 +127,26 @@ class WSMessageError(NamedTuple):
type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR


WSMessage = Union[
WSMessageContinuation,
WSMessageText,
WSMessageBinary,
WSMessagePing,
WSMessagePong,
WSMessageClose,
WSMessageClosing,
WSMessageClosed,
WSMessageError,
]
# Message types common to both TEXT modes: every message class except the
# two TEXT variants, which differ only in the type of their ``data`` field.
_WSMessageBase = (
    WSMessageContinuation
    | WSMessageBinary
    | WSMessagePing
    | WSMessagePong
    | WSMessageClose
    | WSMessageClosing
    | WSMessageClosed
    | WSMessageError
)

# Every message type, including both the str and the bytes TEXT variant.
# Use this when the decode_text setting is not statically known.
WSMessage = _WSMessageBase | WSMessageText | WSMessageTextBytes

# Messages that can be received when decode_text=True (the default):
# TEXT messages carry ``str`` data.
WSMessageDecodeText = _WSMessageBase | WSMessageText

# Messages that can be received when decode_text=False:
# TEXT messages carry raw ``bytes`` data.
WSMessageNoDecodeText = _WSMessageBase | WSMessageTextBytes

WS_CLOSED_MESSAGE = WSMessageClosed()
WS_CLOSING_MESSAGE = WSMessageClosing()
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cdef object TUPLE_NEW
cdef object WSMsgType

cdef object WSMessageText
cdef object WSMessageTextBytes
cdef object WSMessageBinary
cdef object WSMessagePing
cdef object WSMessagePong
Expand Down Expand Up @@ -66,6 +67,7 @@ cdef class WebSocketReader:

cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef bint _decode_text

cdef Exception _exc
cdef bytearray _partial
Expand Down
38 changes: 25 additions & 13 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
WSMessagePing,
WSMessagePong,
WSMessageText,
WSMessageTextBytes,
WSMsgType,
)

Expand Down Expand Up @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage:

class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
self,
queue: WebSocketDataQueue,
max_msg_size: int,
compress: bool = True,
decode_text: bool = True,
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._decode_text = decode_text

self._exc: Exception | None = None
self._partial = bytearray()
Expand Down Expand Up @@ -270,18 +276,24 @@ def _handle_frame(

size = len(payload_merged)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
if self._decode_text:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
else:
# Return raw bytes for TEXT messages when decode_text=False
msg = TUPLE_NEW(
WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
)
else:
msg = TUPLE_NEW(
WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
Expand Down
134 changes: 117 additions & 17 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
)
from contextlib import suppress
from types import TracebackType
from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final
from typing import (
TYPE_CHECKING,
Any,
Final,
Generic,
Literal,
TypedDict,
TypeVar,
final,
overload,
)

from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL, Query
Expand Down Expand Up @@ -187,6 +197,27 @@ class _RequestOptions(TypedDict, total=False):
middlewares: Sequence[ClientMiddlewareType] | None


class _WSConnectOptions(TypedDict, total=False):
    """Typed shape of the ``**kwargs`` taken by the ``ws_connect`` overloads.

    Used with ``Unpack`` in the typing-only overloads of ``ws_connect`` and
    ``_ws_connect``. ``total=False`` makes every key optional.

    ``decode_text`` is intentionally not listed here: the overloads declare
    it as an explicit keyword so that its ``Literal`` type can select the
    return type.
    """

    method: str
    protocols: Collection[str]
    # Quoted so the annotation is not evaluated at class-creation time.
    timeout: "ClientWSTimeout | _SENTINEL"
    receive_timeout: float | None
    autoclose: bool
    autoping: bool
    heartbeat: float | None
    auth: BasicAuth | None
    origin: str | None
    params: Query
    headers: LooseHeaders | None
    proxy: StrOrURL | None
    proxy_auth: BasicAuth | None
    ssl: SSLContext | bool | Fingerprint
    server_hostname: str | None
    proxy_headers: LooseHeaders | None
    compress: int
    max_msg_size: int


@frozen_dataclass_decorator
class ClientTimeout:
total: float | None = None
Expand Down Expand Up @@ -215,7 +246,11 @@ class ClientTimeout:
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})

_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
# Result type carried by _BaseRequestContextManager: a ClientResponse or a
# ClientWebSocketResponse (any decode_text parameterization fits the bound).
# The bound is a string, so it is not evaluated when the TypeVar is created.
# NOTE(review): presumably covariant so that a manager of
# ClientWebSocketResponse[Literal[True]] / [Literal[False]] is accepted where
# a manager of ClientWebSocketResponse[bool] is expected -- confirm.
_RetType_co = TypeVar(
    "_RetType_co",
    bound="ClientResponse | ClientWebSocketResponse[bool]",
    covariant=True,
)
_CharsetResolver = Callable[[ClientResponse, bytes], str]


Expand Down Expand Up @@ -866,6 +901,35 @@ async def _connect_and_send_request(
)
raise

if sys.version_info >= (3, 11) and TYPE_CHECKING:
    # Typing-only overloads: TYPE_CHECKING is False at runtime, so this
    # branch is only ever seen by static type checkers. They map the
    # decode_text argument to the type parameter of the returned
    # ClientWebSocketResponse.
    # NOTE(review): the 3.11 gate presumably tracks availability of
    # ``Unpack`` -- confirm against the file's imports.

    # decode_text omitted or literally True -> ClientWebSocketResponse[Literal[True]].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[True] = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

    # decode_text literally False (no default, so it must be passed
    # explicitly) -> ClientWebSocketResponse[Literal[False]].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[False],
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

    # Fallback for a plain ``bool`` whose value is not statically known
    # -> ClientWebSocketResponse[bool].
    @overload
    def ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: bool = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...

def ws_connect(
self,
url: StrOrURL,
Expand All @@ -888,7 +952,8 @@ def ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> "_WSRequestContextManager":
decode_text: bool = True,
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
"""Initiate websocket connection."""
return _WSRequestContextManager(
self._ws_connect(
Expand All @@ -911,9 +976,39 @@ def ws_connect(
proxy_headers=proxy_headers,
compress=compress,
max_msg_size=max_msg_size,
decode_text=decode_text,
)
)

if sys.version_info >= (3, 11) and TYPE_CHECKING:
    # Typing-only overloads: TYPE_CHECKING is False at runtime, so this
    # branch is only ever seen by static type checkers. They map the
    # decode_text argument to the type parameter of the awaited
    # ClientWebSocketResponse.

    # decode_text omitted or literally True -> ClientWebSocketResponse[Literal[True]].
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[True] = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    # decode_text literally False (no default, so it must be passed
    # explicitly) -> ClientWebSocketResponse[Literal[False]].
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[False],
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    # Fallback for a plain ``bool`` whose value is not statically known
    # -> ClientWebSocketResponse[bool].
    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: bool = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[bool]": ...

async def _ws_connect(
self,
url: StrOrURL,
Expand All @@ -936,7 +1031,8 @@ async def _ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> ClientWebSocketResponse:
decode_text: bool = True,
) -> "ClientWebSocketResponse[bool]":
if timeout is not sentinel:
if isinstance(timeout, ClientWSTimeout):
ws_timeout = timeout
Expand Down Expand Up @@ -1098,7 +1194,9 @@ async def _ws_connect(
transport = conn.transport
assert transport is not None
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
)
writer = WebSocketWriter(
conn_proto,
transport,
Expand Down Expand Up @@ -1373,31 +1471,33 @@ async def __aexit__(
await self.close()


class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]):
class _BaseRequestContextManager(
Coroutine[Any, Any, _RetType_co], Generic[_RetType_co]
):
__slots__ = ("_coro", "_resp")

def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro
def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None:
self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro

def send(self, arg: None) -> "asyncio.Future[Any]":
def send(self, arg: None) -> asyncio.Future[Any]:
return self._coro.send(arg)

def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]":
def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]:
return self._coro.throw(*args, **kwargs)

def close(self) -> None:
return self._coro.close()

def __await__(self) -> Generator[Any, None, _RetType]:
def __await__(self) -> Generator[Any, None, _RetType_co]:
ret = self._coro.__await__()
return ret

def __iter__(self) -> Generator[Any, None, _RetType]:
def __iter__(self) -> Generator[Any, None, _RetType_co]:
return self.__await__()

async def __aenter__(self) -> _RetType:
self._resp: _RetType = await self._coro
return await self._resp.__aenter__()
async def __aenter__(self) -> _RetType_co:
self._resp: _RetType_co = await self._coro
return await self._resp.__aenter__() # type: ignore[return-value]

async def __aexit__(
self,
Expand All @@ -1409,15 +1509,15 @@ async def __aexit__(


_RequestContextManager = _BaseRequestContextManager[ClientResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]]


class _SessionRequestContextManager:
__slots__ = ("_coro", "_resp", "_session")

def __init__(
self,
coro: Coroutine["asyncio.Future[Any]", None, ClientResponse],
coro: Coroutine[asyncio.Future[Any], None, ClientResponse],
session: ClientSession,
) -> None:
self._coro = coro
Expand Down
Loading
Loading