Skip to content

Commit 2b1af69

Browse files
bdraco, Dreamsorcerer, pre-commit-ci[bot]
committed
Add decode_text parameter to WebSocket for receiving TEXT as bytes (#11764)
Co-authored-by: Sam Bull <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 8b919d3)
1 parent fdd9c1d commit 2b1af69

20 files changed

+884
-71
lines changed

CHANGES/11763.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`.

CHANGES/11764.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11763.feature.rst

aiohttp/_websocket/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
6262
return loads(self.data)
6363

6464

65+
class WSMessageTextBytes(NamedTuple):
66+
"""WebSocket TEXT message with raw bytes (no UTF-8 decoding)."""
67+
68+
type: WSMsgType
69+
data: bytes
70+
extra: str | None
71+
72+
def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any:
73+
"""Return parsed JSON data."""
74+
return loads(self.data)
75+
76+
77+
# Type aliases for message types based on decode_text setting
78+
# When decode_text=True, TEXT messages have str data (WSMessage)
79+
# When decode_text=False, TEXT messages have bytes data (WSMessageTextBytes)
80+
WSMessageDecodeText = WSMessage
81+
WSMessageNoDecodeText = WSMessage | WSMessageTextBytes
82+
83+
6584
# Constructing the tuple directly to avoid the overhead of
6685
# the lambda and arg processing since NamedTuples are constructed
6786
# with a run time built lambda

aiohttp/_websocket/reader_c.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ cdef object TUPLE_NEW
2626

2727
cdef object WSMsgType
2828
cdef object WSMessage
29+
cdef object WSMessageTextBytes
2930

3031
cdef object WS_MSG_TYPE_TEXT
3132
cdef object WS_MSG_TYPE_BINARY
@@ -60,6 +61,7 @@ cdef class WebSocketReader:
6061

6162
cdef WebSocketDataQueue queue
6263
cdef unsigned int _max_msg_size
64+
cdef bint _decode_text
6365

6466
cdef Exception _exc
6567
cdef bytearray _partial

aiohttp/_websocket/reader_py.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
WebSocketError,
1616
WSCloseCode,
1717
WSMessage,
18+
WSMessageTextBytes,
1819
WSMsgType,
1920
)
2021

@@ -67,7 +68,7 @@ def __init__(
6768
self._eof = False
6869
self._waiter: asyncio.Future[None] | None = None
6970
self._exception: BaseException | None = None
70-
self._buffer: deque[tuple[WSMessage, int]] = deque()
71+
self._buffer: deque[tuple[WSMessage | WSMessageTextBytes, int]] = deque()
7172
self._get_buffer = self._buffer.popleft
7273
self._put_buffer = self._buffer.append
7374

@@ -100,7 +101,9 @@ def feed_eof(self) -> None:
100101
self._release_waiter()
101102
self._exception = None # Break cyclic references
102103

103-
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
104+
def feed_data(
105+
self, data: "WSMessage | WSMessageTextBytes", size: "cython_int"
106+
) -> None:
104107
self._size += size
105108
self._put_buffer((data, size))
106109
self._release_waiter()
@@ -132,10 +135,15 @@ def _read_from_buffer(self) -> WSMessage:
132135

133136
class WebSocketReader:
134137
def __init__(
135-
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
138+
self,
139+
queue: WebSocketDataQueue,
140+
max_msg_size: int,
141+
compress: bool = True,
142+
decode_text: bool = True,
136143
) -> None:
137144
self.queue = queue
138145
self._max_msg_size = max_msg_size
146+
self._decode_text = decode_text
139147

140148
self._exc: Exception | None = None
141149
self._partial = bytearray()
@@ -262,21 +270,30 @@ def _handle_frame(
262270
payload_merged = bytes(assembled_payload)
263271

264272
if opcode == OP_CODE_TEXT:
265-
try:
266-
text = payload_merged.decode("utf-8")
267-
except UnicodeDecodeError as exc:
268-
raise WebSocketError(
269-
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
270-
) from exc
271-
272-
# XXX: The Text and Binary messages here can be a performance
273-
# bottleneck, so we use tuple.__new__ to improve performance.
274-
# This is not type safe, but many tests should fail in
275-
# test_client_ws_functional.py if this is wrong.
276-
self.queue.feed_data(
277-
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
278-
len(payload_merged),
279-
)
273+
if self._decode_text:
274+
try:
275+
text = payload_merged.decode("utf-8")
276+
except UnicodeDecodeError as exc:
277+
raise WebSocketError(
278+
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
279+
) from exc
280+
281+
# XXX: The Text and Binary messages here can be a performance
282+
# bottleneck, so we use tuple.__new__ to improve performance.
283+
# This is not type safe, but many tests should fail in
284+
# test_client_ws_functional.py if this is wrong.
285+
self.queue.feed_data(
286+
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
287+
len(payload_merged),
288+
)
289+
else:
290+
# Return raw bytes for TEXT messages when decode_text=False
291+
self.queue.feed_data(
292+
TUPLE_NEW(
293+
WSMessageTextBytes, (WS_MSG_TYPE_TEXT, payload_merged, "")
294+
),
295+
len(payload_merged),
296+
)
280297
else:
281298
self.queue.feed_data(
282299
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),

aiohttp/client.py

Lines changed: 120 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections.abc import (
1212
Awaitable,
1313
Callable,
14+
Collection,
1415
Coroutine,
1516
Generator,
1617
Iterable,
@@ -19,7 +20,16 @@
1920
)
2021
from contextlib import suppress
2122
from types import TracebackType
22-
from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar
23+
from typing import (
24+
TYPE_CHECKING,
25+
Any,
26+
Final,
27+
Generic,
28+
Literal,
29+
TypedDict,
30+
TypeVar,
31+
overload,
32+
)
2333

2434
import attr
2535
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
@@ -186,6 +196,30 @@ class _RequestOptions(TypedDict, total=False):
186196
middlewares: Sequence[ClientMiddlewareType] | None
187197

188198

199+
class _WSConnectOptions(TypedDict, total=False):
200+
method: str
201+
protocols: Collection[str]
202+
timeout: "ClientWSTimeout | _SENTINEL"
203+
receive_timeout: float | None
204+
autoclose: bool
205+
autoping: bool
206+
heartbeat: float | None
207+
auth: BasicAuth | None
208+
origin: str | None
209+
params: Query
210+
headers: LooseHeaders | None
211+
proxy: StrOrURL | None
212+
proxy_auth: BasicAuth | None
213+
ssl: SSLContext | bool | Fingerprint
214+
verify_ssl: bool | None
215+
fingerprint: bytes | None
216+
ssl_context: SSLContext | None
217+
server_hostname: str | None
218+
proxy_headers: LooseHeaders | None
219+
compress: int
220+
max_msg_size: int
221+
222+
189223
@attr.s(auto_attribs=True, frozen=True, slots=True)
190224
class ClientTimeout:
191225
total: float | None = None
@@ -214,7 +248,11 @@ class ClientTimeout:
214248
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
215249
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
216250

217-
_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
251+
_RetType_co = TypeVar(
252+
"_RetType_co",
253+
bound="ClientResponse | ClientWebSocketResponse[bool]",
254+
covariant=True,
255+
)
218256
_CharsetResolver = Callable[[ClientResponse, bytes], str]
219257

220258

@@ -917,6 +955,35 @@ async def _connect_and_send_request(
917955
)
918956
raise
919957

958+
if sys.version_info >= (3, 11) and TYPE_CHECKING:
959+
960+
@overload
961+
def ws_connect(
962+
self,
963+
url: StrOrURL,
964+
*,
965+
decode_text: Literal[True] = ...,
966+
**kwargs: Unpack[_WSConnectOptions],
967+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...
968+
969+
@overload
970+
def ws_connect(
971+
self,
972+
url: StrOrURL,
973+
*,
974+
decode_text: Literal[False],
975+
**kwargs: Unpack[_WSConnectOptions],
976+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...
977+
978+
@overload
979+
def ws_connect(
980+
self,
981+
url: StrOrURL,
982+
*,
983+
decode_text: bool = ...,
984+
**kwargs: Unpack[_WSConnectOptions],
985+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...
986+
920987
def ws_connect(
921988
self,
922989
url: StrOrURL,
@@ -942,7 +1009,8 @@ def ws_connect(
9421009
proxy_headers: LooseHeaders | None = None,
9431010
compress: int = 0,
9441011
max_msg_size: int = 4 * 1024 * 1024,
945-
) -> "_WSRequestContextManager":
1012+
decode_text: bool = True,
1013+
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
9461014
"""Initiate websocket connection."""
9471015
return _WSRequestContextManager(
9481016
self._ws_connect(
@@ -968,9 +1036,39 @@ def ws_connect(
9681036
proxy_headers=proxy_headers,
9691037
compress=compress,
9701038
max_msg_size=max_msg_size,
1039+
decode_text=decode_text,
9711040
)
9721041
)
9731042

1043+
if sys.version_info >= (3, 11) and TYPE_CHECKING:
1044+
1045+
@overload
1046+
async def _ws_connect(
1047+
self,
1048+
url: StrOrURL,
1049+
*,
1050+
decode_text: Literal[True] = ...,
1051+
**kwargs: Unpack[_WSConnectOptions],
1052+
) -> "ClientWebSocketResponse[Literal[True]]": ...
1053+
1054+
@overload
1055+
async def _ws_connect(
1056+
self,
1057+
url: StrOrURL,
1058+
*,
1059+
decode_text: Literal[False],
1060+
**kwargs: Unpack[_WSConnectOptions],
1061+
) -> "ClientWebSocketResponse[Literal[False]]": ...
1062+
1063+
@overload
1064+
async def _ws_connect(
1065+
self,
1066+
url: StrOrURL,
1067+
*,
1068+
decode_text: bool = ...,
1069+
**kwargs: Unpack[_WSConnectOptions],
1070+
) -> "ClientWebSocketResponse[bool]": ...
1071+
9741072
async def _ws_connect(
9751073
self,
9761074
url: StrOrURL,
@@ -996,7 +1094,8 @@ async def _ws_connect(
9961094
proxy_headers: LooseHeaders | None = None,
9971095
compress: int = 0,
9981096
max_msg_size: int = 4 * 1024 * 1024,
999-
) -> ClientWebSocketResponse:
1097+
decode_text: bool = True,
1098+
) -> "ClientWebSocketResponse[bool]":
10001099
if timeout is not sentinel:
10011100
if isinstance(timeout, ClientWSTimeout):
10021101
ws_timeout = timeout
@@ -1162,7 +1261,9 @@ async def _ws_connect(
11621261
transport = conn.transport
11631262
assert transport is not None
11641263
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
1165-
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
1264+
conn_proto.set_parser(
1265+
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
1266+
)
11661267
writer = WebSocketWriter(
11671268
conn_proto,
11681269
transport,
@@ -1467,32 +1568,34 @@ async def __aexit__(
14671568
await self.close()
14681569

14691570

1470-
class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]):
1571+
class _BaseRequestContextManager(
1572+
Coroutine[Any, Any, _RetType_co], Generic[_RetType_co]
1573+
):
14711574

14721575
__slots__ = ("_coro", "_resp")
14731576

1474-
def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
1475-
self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro
1577+
def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None:
1578+
self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro
14761579

1477-
def send(self, arg: None) -> "asyncio.Future[Any]":
1580+
def send(self, arg: None) -> asyncio.Future[Any]:
14781581
return self._coro.send(arg)
14791582

1480-
def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]":
1583+
def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]:
14811584
return self._coro.throw(*args, **kwargs)
14821585

14831586
def close(self) -> None:
14841587
return self._coro.close()
14851588

1486-
def __await__(self) -> Generator[Any, None, _RetType]:
1589+
def __await__(self) -> Generator[Any, None, _RetType_co]:
14871590
ret = self._coro.__await__()
14881591
return ret
14891592

1490-
def __iter__(self) -> Generator[Any, None, _RetType]:
1593+
def __iter__(self) -> Generator[Any, None, _RetType_co]:
14911594
return self.__await__()
14921595

1493-
async def __aenter__(self) -> _RetType:
1494-
self._resp: _RetType = await self._coro
1495-
return await self._resp.__aenter__()
1596+
async def __aenter__(self) -> _RetType_co:
1597+
self._resp: _RetType_co = await self._coro
1598+
return await self._resp.__aenter__() # type: ignore[return-value]
14961599

14971600
async def __aexit__(
14981601
self,
@@ -1504,7 +1607,7 @@ async def __aexit__(
15041607

15051608

15061609
_RequestContextManager = _BaseRequestContextManager[ClientResponse]
1507-
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]
1610+
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]]
15081611

15091612

15101613
class _SessionRequestContextManager:
@@ -1513,7 +1616,7 @@ class _SessionRequestContextManager:
15131616

15141617
def __init__(
15151618
self,
1516-
coro: Coroutine["asyncio.Future[Any]", None, ClientResponse],
1619+
coro: Coroutine[asyncio.Future[Any], None, ClientResponse],
15171620
session: ClientSession,
15181621
) -> None:
15191622
self._coro = coro

0 commit comments

Comments
 (0)