Skip to content

Commit 8b919d3

Browse files
bdracoDreamsorcererpre-commit-ci[bot]
authored
Add decode_text parameter to WebSocket for receiving TEXT as bytes (#11764)
Co-authored-by: Sam Bull <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5d67dcf commit 8b919d3

20 files changed

+883
-79
lines changed

CHANGES/11763.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`.

CHANGES/11764.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11763.feature.rst

aiohttp/_websocket/models.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
from collections.abc import Callable
55
from enum import IntEnum
6-
from typing import Any, Final, Literal, NamedTuple, Union, cast
6+
from typing import Any, Final, Literal, NamedTuple, cast
77

88
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
99

@@ -59,6 +59,19 @@ def json(
5959
return loads(self.data)
6060

6161

62+
class WSMessageTextBytes(NamedTuple):
63+
"""WebSocket TEXT message with raw bytes (no UTF-8 decoding)."""
64+
65+
data: bytes
66+
size: int
67+
extra: str | None = None
68+
type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT
69+
70+
def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any:
71+
"""Return parsed JSON data."""
72+
return loads(self.data)
73+
74+
6275
class WSMessageBinary(NamedTuple):
6376
data: bytes
6477
size: int
@@ -114,17 +127,26 @@ class WSMessageError(NamedTuple):
114127
type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR
115128

116129

117-
WSMessage = Union[
118-
WSMessageContinuation,
119-
WSMessageText,
120-
WSMessageBinary,
121-
WSMessagePing,
122-
WSMessagePong,
123-
WSMessageClose,
124-
WSMessageClosing,
125-
WSMessageClosed,
126-
WSMessageError,
127-
]
130+
# Base message types (excluding TEXT variants)
131+
_WSMessageBase = (
132+
WSMessageContinuation
133+
| WSMessageBinary
134+
| WSMessagePing
135+
| WSMessagePong
136+
| WSMessageClose
137+
| WSMessageClosing
138+
| WSMessageClosed
139+
| WSMessageError
140+
)
141+
142+
# All message types
143+
WSMessage = _WSMessageBase | WSMessageText | WSMessageTextBytes
144+
145+
# Message type when decode_text=True (default) - TEXT messages have str data
146+
WSMessageDecodeText = _WSMessageBase | WSMessageText
147+
148+
# Message type when decode_text=False - TEXT messages have bytes data
149+
WSMessageNoDecodeText = _WSMessageBase | WSMessageTextBytes
128150

129151
WS_CLOSED_MESSAGE = WSMessageClosed()
130152
WS_CLOSING_MESSAGE = WSMessageClosing()

aiohttp/_websocket/reader_c.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ cdef object TUPLE_NEW
2727
cdef object WSMsgType
2828

2929
cdef object WSMessageText
30+
cdef object WSMessageTextBytes
3031
cdef object WSMessageBinary
3132
cdef object WSMessagePing
3233
cdef object WSMessagePong
@@ -66,6 +67,7 @@ cdef class WebSocketReader:
6667

6768
cdef WebSocketDataQueue queue
6869
cdef unsigned int _max_msg_size
70+
cdef bint _decode_text
6971

7072
cdef Exception _exc
7173
cdef bytearray _partial

aiohttp/_websocket/reader_py.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
WSMessagePing,
2121
WSMessagePong,
2222
WSMessageText,
23+
WSMessageTextBytes,
2324
WSMsgType,
2425
)
2526

@@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage:
139140

140141
class WebSocketReader:
141142
def __init__(
142-
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
143+
self,
144+
queue: WebSocketDataQueue,
145+
max_msg_size: int,
146+
compress: bool = True,
147+
decode_text: bool = True,
143148
) -> None:
144149
self.queue = queue
145150
self._max_msg_size = max_msg_size
151+
self._decode_text = decode_text
146152

147153
self._exc: Exception | None = None
148154
self._partial = bytearray()
@@ -270,18 +276,24 @@ def _handle_frame(
270276

271277
size = len(payload_merged)
272278
if opcode == OP_CODE_TEXT:
273-
try:
274-
text = payload_merged.decode("utf-8")
275-
except UnicodeDecodeError as exc:
276-
raise WebSocketError(
277-
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
278-
) from exc
279-
280-
# XXX: The Text and Binary messages here can be a performance
281-
# bottleneck, so we use tuple.__new__ to improve performance.
282-
# This is not type safe, but many tests should fail in
283-
# test_client_ws_functional.py if this is wrong.
284-
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
279+
if self._decode_text:
280+
try:
281+
text = payload_merged.decode("utf-8")
282+
except UnicodeDecodeError as exc:
283+
raise WebSocketError(
284+
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
285+
) from exc
286+
287+
# XXX: The Text and Binary messages here can be a performance
288+
# bottleneck, so we use tuple.__new__ to improve performance.
289+
# This is not type safe, but many tests should fail in
290+
# test_client_ws_functional.py if this is wrong.
291+
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
292+
else:
293+
# Return raw bytes for TEXT messages when decode_text=False
294+
msg = TUPLE_NEW(
295+
WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
296+
)
285297
else:
286298
msg = TUPLE_NEW(
287299
WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)

aiohttp/client.py

Lines changed: 117 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,17 @@
2121
)
2222
from contextlib import suppress
2323
from types import TracebackType
24-
from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final
24+
from typing import (
25+
TYPE_CHECKING,
26+
Any,
27+
Final,
28+
Generic,
29+
Literal,
30+
TypedDict,
31+
TypeVar,
32+
final,
33+
overload,
34+
)
2535

2636
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
2737
from yarl import URL, Query
@@ -187,6 +197,27 @@ class _RequestOptions(TypedDict, total=False):
187197
middlewares: Sequence[ClientMiddlewareType] | None
188198

189199

200+
class _WSConnectOptions(TypedDict, total=False):
201+
method: str
202+
protocols: Collection[str]
203+
timeout: "ClientWSTimeout | _SENTINEL"
204+
receive_timeout: float | None
205+
autoclose: bool
206+
autoping: bool
207+
heartbeat: float | None
208+
auth: BasicAuth | None
209+
origin: str | None
210+
params: Query
211+
headers: LooseHeaders | None
212+
proxy: StrOrURL | None
213+
proxy_auth: BasicAuth | None
214+
ssl: SSLContext | bool | Fingerprint
215+
server_hostname: str | None
216+
proxy_headers: LooseHeaders | None
217+
compress: int
218+
max_msg_size: int
219+
220+
190221
@frozen_dataclass_decorator
191222
class ClientTimeout:
192223
total: float | None = None
@@ -215,7 +246,11 @@ class ClientTimeout:
215246
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
216247
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
217248

218-
_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
249+
_RetType_co = TypeVar(
250+
"_RetType_co",
251+
bound="ClientResponse | ClientWebSocketResponse[bool]",
252+
covariant=True,
253+
)
219254
_CharsetResolver = Callable[[ClientResponse, bytes], str]
220255

221256

@@ -866,6 +901,35 @@ async def _connect_and_send_request(
866901
)
867902
raise
868903

904+
if sys.version_info >= (3, 11) and TYPE_CHECKING:
905+
906+
@overload
907+
def ws_connect(
908+
self,
909+
url: StrOrURL,
910+
*,
911+
decode_text: Literal[True] = ...,
912+
**kwargs: Unpack[_WSConnectOptions],
913+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...
914+
915+
@overload
916+
def ws_connect(
917+
self,
918+
url: StrOrURL,
919+
*,
920+
decode_text: Literal[False],
921+
**kwargs: Unpack[_WSConnectOptions],
922+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...
923+
924+
@overload
925+
def ws_connect(
926+
self,
927+
url: StrOrURL,
928+
*,
929+
decode_text: bool = ...,
930+
**kwargs: Unpack[_WSConnectOptions],
931+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...
932+
869933
def ws_connect(
870934
self,
871935
url: StrOrURL,
@@ -888,7 +952,8 @@ def ws_connect(
888952
proxy_headers: LooseHeaders | None = None,
889953
compress: int = 0,
890954
max_msg_size: int = 4 * 1024 * 1024,
891-
) -> "_WSRequestContextManager":
955+
decode_text: bool = True,
956+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
892957
"""Initiate websocket connection."""
893958
return _WSRequestContextManager(
894959
self._ws_connect(
@@ -911,9 +976,39 @@ def ws_connect(
911976
proxy_headers=proxy_headers,
912977
compress=compress,
913978
max_msg_size=max_msg_size,
979+
decode_text=decode_text,
914980
)
915981
)
916982

983+
if sys.version_info >= (3, 11) and TYPE_CHECKING:
984+
985+
@overload
986+
async def _ws_connect(
987+
self,
988+
url: StrOrURL,
989+
*,
990+
decode_text: Literal[True] = ...,
991+
**kwargs: Unpack[_WSConnectOptions],
992+
) -> "ClientWebSocketResponse[Literal[True]]": ...
993+
994+
@overload
995+
async def _ws_connect(
996+
self,
997+
url: StrOrURL,
998+
*,
999+
decode_text: Literal[False],
1000+
**kwargs: Unpack[_WSConnectOptions],
1001+
) -> "ClientWebSocketResponse[Literal[False]]": ...
1002+
1003+
@overload
1004+
async def _ws_connect(
1005+
self,
1006+
url: StrOrURL,
1007+
*,
1008+
decode_text: bool = ...,
1009+
**kwargs: Unpack[_WSConnectOptions],
1010+
) -> "ClientWebSocketResponse[bool]": ...
1011+
9171012
async def _ws_connect(
9181013
self,
9191014
url: StrOrURL,
@@ -936,7 +1031,8 @@ async def _ws_connect(
9361031
proxy_headers: LooseHeaders | None = None,
9371032
compress: int = 0,
9381033
max_msg_size: int = 4 * 1024 * 1024,
939-
) -> ClientWebSocketResponse:
1034+
decode_text: bool = True,
1035+
) -> "ClientWebSocketResponse[bool]":
9401036
if timeout is not sentinel:
9411037
if isinstance(timeout, ClientWSTimeout):
9421038
ws_timeout = timeout
@@ -1098,7 +1194,9 @@ async def _ws_connect(
10981194
transport = conn.transport
10991195
assert transport is not None
11001196
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
1101-
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
1197+
conn_proto.set_parser(
1198+
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
1199+
)
11021200
writer = WebSocketWriter(
11031201
conn_proto,
11041202
transport,
@@ -1373,31 +1471,33 @@ async def __aexit__(
13731471
await self.close()
13741472

13751473

1376-
class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]):
1474+
class _BaseRequestContextManager(
1475+
Coroutine[Any, Any, _RetType_co], Generic[_RetType_co]
1476+
):
13771477
__slots__ = ("_coro", "_resp")
13781478

1379-
def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
1380-
self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro
1479+
def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None:
1480+
self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro
13811481

1382-
def send(self, arg: None) -> "asyncio.Future[Any]":
1482+
def send(self, arg: None) -> asyncio.Future[Any]:
13831483
return self._coro.send(arg)
13841484

1385-
def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]":
1485+
def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]:
13861486
return self._coro.throw(*args, **kwargs)
13871487

13881488
def close(self) -> None:
13891489
return self._coro.close()
13901490

1391-
def __await__(self) -> Generator[Any, None, _RetType]:
1491+
def __await__(self) -> Generator[Any, None, _RetType_co]:
13921492
ret = self._coro.__await__()
13931493
return ret
13941494

1395-
def __iter__(self) -> Generator[Any, None, _RetType]:
1495+
def __iter__(self) -> Generator[Any, None, _RetType_co]:
13961496
return self.__await__()
13971497

1398-
async def __aenter__(self) -> _RetType:
1399-
self._resp: _RetType = await self._coro
1400-
return await self._resp.__aenter__()
1498+
async def __aenter__(self) -> _RetType_co:
1499+
self._resp: _RetType_co = await self._coro
1500+
return await self._resp.__aenter__() # type: ignore[return-value]
14011501

14021502
async def __aexit__(
14031503
self,
@@ -1409,15 +1509,15 @@ async def __aexit__(
14091509

14101510

14111511
_RequestContextManager = _BaseRequestContextManager[ClientResponse]
1412-
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
1512+
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]]
14131513

14141514

14151515
class _SessionRequestContextManager:
14161516
__slots__ = ("_coro", "_resp", "_session")
14171517

14181518
def __init__(
14191519
self,
1420-
coro: Coroutine["asyncio.Future[Any]", None, ClientResponse],
1520+
coro: Coroutine[asyncio.Future[Any], None, ClientResponse],
14211521
session: ClientSession,
14221522
) -> None:
14231523
self._coro = coro

0 commit comments

Comments
 (0)