From 11f2435fe7f76f089ffa39c253b743e93a2c39e9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 17 Nov 2025 12:09:42 -0600 Subject: [PATCH 01/42] Add option to avoid decoding WebSocket TEXT --- aiohttp/_websocket/models.py | 16 +++ aiohttp/_websocket/reader_c.pxd | 2 + aiohttp/_websocket/reader_py.py | 38 ++++-- aiohttp/client.py | 128 +++++++++++++++++- aiohttp/client_ws.py | 30 ++++- aiohttp/http_websocket.py | 2 + aiohttp/test_utils.py | 19 ++- aiohttp/web_ws.py | 37 +++++- tests/test_client_ws_functional.py | 171 +++++++++++++++++++++++++ tests/test_web_websocket_functional.py | 160 +++++++++++++++++++++++ 10 files changed, 572 insertions(+), 31 deletions(-) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 085fb460cb5..b42f88fd5f3 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -59,6 +59,21 @@ def json( return loads(self.data) +class WSMessageTextBytes(NamedTuple): + """WebSocket TEXT message with raw bytes (no UTF-8 decoding).""" + + data: bytes + size: int + extra: str | None = None + type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT + + def json( + self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads + ) -> Any: + """Return parsed JSON data.""" + return loads(self.data) + + class WSMessageBinary(NamedTuple): data: bytes size: int @@ -117,6 +132,7 @@ class WSMessageError(NamedTuple): WSMessage = Union[ WSMessageContinuation, WSMessageText, + WSMessageTextBytes, WSMessageBinary, WSMessagePing, WSMessagePong, diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 9a6fdae3e97..7e5e46f13c7 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -27,6 +27,7 @@ cdef object TUPLE_NEW cdef object WSMsgType cdef object WSMessageText +cdef object WSMessageTextBytes cdef object WSMessageBinary cdef object WSMessagePing cdef object WSMessagePong @@ -66,6 +67,7 @@ cdef class WebSocketReader: cdef WebSocketDataQueue queue cdef unsigned int _max_msg_size + cdef bint _decode_text cdef Exception _exc cdef bytearray _partial diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 5bcc2ecfb78..e0088a47af8 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -20,6 +20,7 @@ WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage: class WebSocketReader: def __init__( - self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True + self, + queue: WebSocketDataQueue, + max_msg_size: int, + compress: bool = True, + decode_text: bool = True, ) -> None: self.queue = queue self._max_msg_size = max_msg_size + self._decode_text = decode_text self._exc: Exception | None = None self._partial = bytearray() @@ -270,18 +276,24 @@ def _handle_frame( size = len(payload_merged) if opcode == OP_CODE_TEXT: - try: - text = payload_merged.decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - - # XXX: The Text and Binary messages here can be a performance - # bottleneck, so we use tuple.__new__ to improve performance. - # This is not type safe, but many tests should fail in - # test_client_ws_functional.py if this is wrong. - msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + if self._decode_text: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. + msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + else: + # Return raw bytes for TEXT messages when decode_text=False + msg = TUPLE_NEW( + WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT) + ) else: msg = TUPLE_NEW( WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) diff --git a/aiohttp/client.py b/aiohttp/client.py index 026006023ce..861ccf278d6 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -21,7 +21,17 @@ ) from contextlib import suppress from types import TracebackType -from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final +from typing import ( + TYPE_CHECKING, + Any, + Final, + Generic, + Literal, + TypedDict, + TypeVar, + final, + overload, +) from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr from yarl import URL, Query @@ -866,6 +876,59 @@ async def _connect_and_send_request( ) raise + # Overloads for type-safe decode_text parameter + @overload + def ws_connect( + self, + url: StrOrURL, + *, + method: str = ..., + protocols: Collection[str] = ..., + timeout: ClientWSTimeout | _SENTINEL = ..., + receive_timeout: float | None = ..., + autoclose: bool = ..., + autoping: bool = ..., + heartbeat: float | None = ..., + auth: BasicAuth | None = ..., + origin: str | None = ..., + params: Query = ..., + headers: LooseHeaders | None = ..., + proxy: StrOrURL | None = ..., + proxy_auth: BasicAuth | None = ..., + ssl: SSLContext | bool | Fingerprint = ..., + server_hostname: str | None = ..., + proxy_headers: LooseHeaders | None = ..., + compress: int = ..., + max_msg_size: int = ..., + decode_text: Literal[True] = ..., + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + method: str = ..., + protocols: Collection[str] = ..., + timeout: ClientWSTimeout | _SENTINEL = ..., + receive_timeout: float | None = ..., + autoclose: bool = ..., + autoping: bool = ..., + heartbeat: float | None = ..., + auth: BasicAuth | None = ..., + origin: str | None = ..., + params: Query = ..., + headers: LooseHeaders | None = ..., + proxy: StrOrURL | None = ..., + proxy_auth: BasicAuth | None = ..., + ssl: SSLContext | bool | Fingerprint = ..., + server_hostname: str | None = ..., + proxy_headers: LooseHeaders | None = ..., + compress: int = ..., + max_msg_size: int = ..., + decode_text: Literal[False] = ..., + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + def ws_connect( self, url: StrOrURL, @@ -888,7 +951,8 @@ def ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, - ) -> "_WSRequestContextManager": + decode_text: bool = True, + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect( @@ -911,9 +975,62 @@ def ws_connect( proxy_headers=proxy_headers, compress=compress, max_msg_size=max_msg_size, + decode_text=decode_text, ) ) + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + method: str = ..., + protocols: Collection[str] = ..., + timeout: ClientWSTimeout | _SENTINEL = ..., + receive_timeout: float | None = ..., + autoclose: bool = ..., + autoping: bool = ..., + heartbeat: float | None = ..., + auth: BasicAuth | None = ..., + origin: str | None = ..., + params: Query = ..., + headers: LooseHeaders | None = ..., + proxy: StrOrURL | None = ..., + proxy_auth: BasicAuth | None = ..., + ssl: SSLContext | bool | Fingerprint = ..., + server_hostname: str | None = ..., + proxy_headers: LooseHeaders | None = ..., + compress: int = ..., + max_msg_size: int = ..., + decode_text: Literal[True] = ..., + ) -> "ClientWebSocketResponse[Literal[True]]": ... + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + method: str = ..., + protocols: Collection[str] = ..., + timeout: ClientWSTimeout | _SENTINEL = ..., + receive_timeout: float | None = ..., + autoclose: bool = ..., + autoping: bool = ..., + heartbeat: float | None = ..., + auth: BasicAuth | None = ..., + origin: str | None = ..., + params: Query = ..., + headers: LooseHeaders | None = ..., + proxy: StrOrURL | None = ..., + proxy_auth: BasicAuth | None = ..., + ssl: SSLContext | bool | Fingerprint = ..., + server_hostname: str | None = ..., + proxy_headers: LooseHeaders | None = ..., + compress: int = ..., + max_msg_size: int = ..., + decode_text: Literal[False] = ..., + ) -> "ClientWebSocketResponse[Literal[False]]": ... + async def _ws_connect( self, url: StrOrURL, @@ -936,7 +1053,8 @@ async def _ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, - ) -> ClientWebSocketResponse: + decode_text: bool = True, + ) -> "ClientWebSocketResponse[Any]": if timeout is not sentinel: if isinstance(timeout, ClientWSTimeout): ws_timeout = timeout @@ -1098,7 +1216,9 @@ async def _ws_connect( transport = conn.transport assert transport is not None reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) - conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + conn_proto.set_parser( + WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader + ) writer = WebSocketWriter( conn_proto, transport, diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 36959aae0c7..77d695b0d9e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -3,7 +3,7 @@ import asyncio import sys from types import TracebackType -from typing import Any, Final +from typing import Any, Final, Generic, Literal, TypeVar, overload from ._websocket.reader import WebSocketDataQueue from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError @@ -28,9 +28,15 @@ if sys.version_info >= (3, 11): import asyncio as async_timeout + from typing import Self else: import async_timeout + Self = TypeVar("Self", bound="ClientWebSocketResponse[Any]") + +# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) +_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True]) + @frozen_dataclass_decorator class ClientWSTimeout: @@ -43,7 +49,7 @@ class ClientWSTimeout: ) -class ClientWebSocketResponse: +class ClientWebSocketResponse(Generic[_DecodeText]): def __init__( self, reader: WebSocketDataQueue, @@ -383,7 +389,21 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: float | None = None) -> str: + @overload + async def receive_str( + self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None + ) -> str: ... + + @overload + async def receive_str( + self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None + ) -> bytes: ... + + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( @@ -408,7 +428,7 @@ async def receive_json( data = await self.receive_str(timeout=timeout) return loads(data) - def __aiter__(self) -> "ClientWebSocketResponse": + def __aiter__(self) -> Self: return self async def __anext__(self) -> WSMessage: @@ -417,7 +437,7 @@ async def __anext__(self) -> WSMessage: raise StopAsyncIteration return msg - async def __aenter__(self) -> "ClientWebSocketResponse": + async def __aenter__(self) -> Self: return self async def __aexit__( diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index f49d8aee287..830318c0b9a 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -17,6 +17,7 @@ WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) from ._websocket.reader import WebSocketReader @@ -48,6 +49,7 @@ "WSMessagePong", "WSMessageBinary", "WSMessageText", + "WSMessageTextBytes", "WSMessagePing", "WSMessageContinuation", ) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 192173b42c8..fc0c4ed1b1e 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterator from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -19,6 +19,7 @@ import aiohttp from aiohttp.client import ( + _BaseRequestContextManager, _RequestContextManager, _RequestOptions, _WSRequestContextManager, @@ -429,7 +430,19 @@ def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: self._request(hdrs.METH_DELETE, path, **kwargs) ) - def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... + + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + + def ws_connect( + self, path: StrOrURL, **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": """Initiate websocket connection. The api corresponds to aiohttp.ClientSession.ws_connect. @@ -439,7 +452,7 @@ def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: async def _ws_connect( self, path: StrOrURL, **kwargs: Any - ) -> ClientWebSocketResponse: + ) -> "ClientWebSocketResponse[Any]": ws = await self._session.ws_connect(self.make_url(path), **kwargs) self._websockets.append(ws) return ws diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 8eee7e3ad71..6731ac48b89 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -5,7 +5,7 @@ import json import sys from collections.abc import Iterable -from typing import Any, Final, Union +from typing import Any, Final, Generic, Literal, TypeVar, Union, overload from multidict import CIMultiDict @@ -43,9 +43,12 @@ if sys.version_info >= (3, 11): import asyncio as async_timeout + from typing import Self else: import async_timeout + Self = TypeVar("Self", bound="WebSocketResponse[Any]") + __all__ = ( "WebSocketResponse", "WebSocketReady", @@ -54,6 +57,9 @@ THRESHOLD_CONNLOST_ACCESS: Final[int] = 5 +# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) +_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True]) + @frozen_dataclass_decorator class WebSocketReady: @@ -64,7 +70,7 @@ def __bool__(self) -> bool: return self.ok -class WebSocketResponse(StreamResponse): +class WebSocketResponse(StreamResponse, Generic[_DecodeText]): _length_check: bool = False _ws_protocol: str | None = None @@ -95,6 +101,7 @@ def __init__( compress: bool = True, max_msg_size: int = 4 * 1024 * 1024, writer_limit: int = DEFAULT_LIMIT, + decode_text: bool = True, ) -> None: super().__init__(status=101) self._protocols = protocols @@ -108,6 +115,7 @@ def __init__( self._compress: bool | int = compress self._max_msg_size = max_msg_size self._writer_limit = writer_limit + self._decode_text = decode_text def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() @@ -341,7 +349,10 @@ def _post_start( self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop) request.protocol.set_parser( WebSocketReader( - self._reader, self._max_msg_size, compress=bool(self._compress) + self._reader, + self._max_msg_size, + compress=bool(self._compress), + decode_text=self._decode_text, ) ) # disable HTTP keepalive for WebSocket @@ -588,13 +599,27 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: float | None = None) -> str: + @overload + async def receive_str( + self: "WebSocketResponse[Literal[True]]", *, timeout: float | None = None + ) -> str: ... + + @overload + async def receive_str( + self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None + ) -> bytes: ... + + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" ) - return msg.data + return msg.data # type: ignore[return-value] async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) @@ -615,7 +640,7 @@ async def write( ) -> None: raise RuntimeError("Cannot call .write() for websocket") - def __aiter__(self) -> "WebSocketResponse": + def __aiter__(self) -> Self: return self async def __anext__(self) -> WSMessage: diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 0bc05f300d4..0e8e3830673 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1305,3 +1305,174 @@ async def websocket_task() -> None: # Cleanup properly websocket._response = mock.Mock() await websocket.close() + + +async def test_receive_text_as_bytes_client_side(aiohttp_client: AiohttpClient) -> None: + """Test client receiving TEXT messages as raw bytes with decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/answer") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Connect with decode_text=False + resp = await client.ws_connect("/", decode_text=False) + await resp.send_str("ask") + + # Receive TEXT message as bytes + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"ask/answer" + + await resp.close() + + +async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: + """Test server receiving TEXT messages as raw bytes with decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(decode_text=False) + await ws.prepare(request) + + # Receive TEXT message as bytes + msg = await ws.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"test message" + + # Send response + await ws.send_bytes(msg.data + b"/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + resp = await client.ws_connect("/") + await resp.send_str("test message") + + msg = await resp.receive() + assert msg.type is WSMsgType.BINARY + assert msg.data == b"test message/reply" + + await resp.close() + + +async def test_receive_text_as_bytes_json_parsing( + aiohttp_client: AiohttpClient, +) -> None: + """Test using orjson or similar parsers with raw bytes from TEXT messages.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + data = json.loads(msg) + await ws.send_str(json.dumps({"response": data["value"] * 2})) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Connect with decode_text=False to get raw bytes + resp = await client.ws_connect("/", decode_text=False) + await resp.send_str(json.dumps({"value": 42})) + + # Receive TEXT message as bytes + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, bytes) + + # Parse JSON directly from bytes (like orjson would) + data = json.loads(msg.data) + assert data == {"response": 84} + + await resp.close() + + +async def test_decode_text_default_true(aiohttp_client: AiohttpClient) -> None: + """Test that decode_text defaults to True for backward compatibility.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive_str() + await ws.send_str(msg + "/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + # Default behavior (decode_text=True) + resp = await client.ws_connect("/") + await resp.send_str("test") + + # Should receive TEXT message as string + msg = await resp.receive() + assert msg.type is WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test/reply" + + await resp.close() + + +async def test_receive_str_returns_bytes_with_decode_text_false( + aiohttp_client: AiohttpClient, +) -> None: + """Test that receive_str() returns bytes when decode_text=False.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str("hello world") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/", decode_text=False) as ws: + # receive_str() should return bytes when decode_text=False + data = await ws.receive_str() + assert isinstance(data, bytes) + assert data == b"hello world" + + +async def test_receive_str_returns_str_with_decode_text_true( + aiohttp_client: AiohttpClient, +) -> None: + """Test that receive_str() returns str when decode_text=True (default).""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str("hello world") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + # receive_str() should return str when decode_text=True (default) + data = await ws.receive_str() + assert isinstance(data, str) + assert data == "hello world" diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index afa76e2d742..8c01cc5a973 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1445,3 +1445,163 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert msg.type is WSMsgType.TEXT assert msg.data == "test" await ws.close() + + +async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: + """Test server receiving TEXT messages as raw bytes with decode_text=False.""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(decode_text=False) + await ws.prepare(request) + + # Receive TEXT message as bytes + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, bytes) + assert msg.data == b"test message" + + # Send response + await ws.send_bytes(msg.data + b"/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("test message") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"test message/reply" + + await ws.close() + + +async def test_receive_text_as_bytes_server_iteration( + aiohttp_client: AiohttpClient, +) -> None: + """Test server iterating over WebSocket with decode_text=False.""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(decode_text=False) + await ws.prepare(request) + + async for msg in ws: + if msg.type is aiohttp.WSMsgType.TEXT: + # msg.data should be bytes + assert isinstance(msg.data, bytes) + # Echo back + await ws.send_bytes(msg.data) + elif msg.type is aiohttp.WSMsgType.BINARY: + await ws.send_bytes(msg.data) + + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + # Send TEXT message + await ws.send_str("hello") + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"hello" + + # Send BINARY message + await ws.send_bytes(b"world") + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.BINARY + assert msg.data == b"world" + + await ws.close() + + +async def test_server_decode_text_default_true(aiohttp_client: AiohttpClient) -> None: + """Test that server decode_text defaults to True for backward compatibility.""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + # No decode_text parameter - should default to True + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test" + + await ws.send_str(msg.data + "/reply") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("test") + + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert isinstance(msg.data, str) + assert msg.data == "test/reply" + + await ws.close() + + +async def test_server_receive_str_returns_bytes_with_decode_text_false( + aiohttp_client: AiohttpClient, +) -> None: + """Test that server receive_str() returns bytes when decode_text=False.""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(decode_text=False) + await ws.prepare(request) + + # receive_str() should return bytes when decode_text=False + data = await ws.receive_str() + assert isinstance(data, bytes) + assert data == b"hello server" + + await ws.send_str("got bytes") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("hello server") + msg = await ws.receive() + assert msg.data == "got bytes" + + +async def test_server_receive_str_returns_str_with_decode_text_true( + aiohttp_client: AiohttpClient, +) -> None: + """Test that server receive_str() returns str when decode_text=True (default).""" + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() # decode_text=True by default + await ws.prepare(request) + + # receive_str() should return str when decode_text=True + data = await ws.receive_str() + assert isinstance(data, str) + assert data == "hello server" + + await ws.send_str("got string") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/") as ws: + await ws.send_str("hello server") + msg = await ws.receive() + assert msg.data == "got string" From 38b3c346c4878d541f6b600624c76b65cd375345 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 17 Nov 2025 12:12:07 -0600 Subject: [PATCH 02/42] touch ups --- aiohttp/web_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 6731ac48b89..e3293de2b12 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -619,7 +619,7 @@ async def receive_str(self, *, timeout: float | None = None) -> str | bytes: raise WSMessageTypeError( f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" ) - return msg.data # type: ignore[return-value] + return msg.data async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) From 104666b740bb30d788a8d6e560865e75a705228a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 17 Nov 2025 12:17:38 -0600 Subject: [PATCH 03/42] decode type --- aiohttp/client.py | 2 +- aiohttp/client_ws.py | 21 ++++++++++++++++++++- aiohttp/web_ws.py | 25 +++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 861ccf278d6..cd51fcd3537 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -225,7 +225,7 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) -_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse) +_RetType = TypeVar("_RetType", bound="ClientResponse | ClientWebSocketResponse[Any]") _CharsetResolver = Callable[[ClientResponse, bytes], str] diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 77d695b0d9e..550dc6bda64 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -2,6 +2,7 @@ import asyncio import sys +from collections.abc import Callable from types import TracebackType from typing import Any, Final, Generic, Literal, TypeVar, overload @@ -419,10 +420,28 @@ async def receive_bytes(self, *, timeout: float | None = None) -> bytes: ) return msg.data + @overload + async def receive_json( + self: "ClientWebSocketResponse[Literal[True]]", + *, + loads: JSONDecoder = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload + async def receive_json( + self: "ClientWebSocketResponse[Literal[False]]", + *, + loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, - loads: JSONDecoder = DEFAULT_JSON_DECODER, + loads: ( + JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] + ) = DEFAULT_JSON_DECODER, timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index e3293de2b12..d9c0d8ada3a 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -4,7 +4,7 @@ import hashlib import json import sys -from collections.abc import Iterable +from collections.abc import Callable, Iterable from typing import Any, Final, Generic, Literal, TypeVar, Union, overload from multidict import CIMultiDict @@ -629,8 +629,29 @@ async def receive_bytes(self, *, timeout: float | None = None) -> bytes: ) return msg.data + @overload + async def receive_json( + self: "WebSocketResponse[Literal[True]]", + *, + loads: JSONDecoder = ..., + timeout: float | None = None, + ) -> Any: ... + + @overload async def receive_json( - self, *, loads: JSONDecoder = json.loads, timeout: float | None = None + self: "WebSocketResponse[Literal[False]]", + *, + loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + + async def receive_json( + self, + *, + loads: ( + JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] + ) = json.loads, + timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) return loads(data) From 348529d69d0861783932809b0c01c6612ea6f630 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 17 Nov 2025 12:21:40 -0600 Subject: [PATCH 04/42] coverage for bytes loads --- aiohttp/web_ws.py | 2 +- tests/test_client_ws_functional.py | 28 ++++++++++++++++++++ tests/test_web_websocket_functional.py | 36 ++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d9c0d8ada3a..959f71cc19b 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -653,7 +653,7 @@ async def receive_json( ) = json.loads, timeout: float | None = None, ) -> Any: - data = await self.receive_str(timeout=timeout) + data: str | bytes = await self.receive_str(timeout=timeout) return loads(data) async def write( diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 0e8e3830673..a06eb4e14c6 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1476,3 +1476,31 @@ async def handler(request: web.Request) -> web.WebSocketResponse: data = await ws.receive_str() assert isinstance(data, str) assert data == "hello world" + + +async def test_receive_json_with_orjson_style_loads( + aiohttp_client: AiohttpClient, +) -> None: + """Test receive_json() with orjson-style loads that accepts bytes.""" + + def orjson_style_loads(data: bytes | bytearray | memoryview | str) -> dict: + """Mock orjson.loads that accepts bytes/str.""" + if isinstance(data, (bytes, bytearray, memoryview)): + data = bytes(data).decode("utf-8") + return json.loads(data) + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.send_str('{"value": 42}') + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/", decode_text=False) as ws: + # receive_json() with orjson-style loads should work with bytes + data = await ws.receive_json(loads=orjson_style_loads) + assert data == {"value": 42} diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 8c01cc5a973..7dbfcdcda8c 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import json import sys import weakref from typing import NoReturn @@ -1605,3 +1606,38 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse: await ws.send_str("hello server") msg = await ws.receive() assert msg.data == "got string" + + +async def test_server_receive_json_with_orjson_style_loads( + aiohttp_client: AiohttpClient, +) -> None: + """Test server receive_json() with orjson-style loads that accepts bytes.""" + + def orjson_style_loads(data: bytes | bytearray | memoryview | str) -> dict: + """Mock orjson.loads that accepts bytes/str.""" + if isinstance(data, (bytes, bytearray, memoryview)): + data = bytes(data).decode("utf-8") + return json.loads(data) + + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(decode_text=False) + await ws.prepare(request) + + # receive_json() with orjson-style loads should work with bytes + data = await ws.receive_json(loads=orjson_style_loads) + assert data == {"test": "value"} + + await ws.send_str("success") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", websocket_handler) + client = await aiohttp_client(app) + + ws = await client.ws_connect("/") + await ws.send_str('{"test": "value"}') + msg = await ws.receive() + assert msg.type is aiohttp.WSMsgType.TEXT + assert msg.data == "success" + await ws.close() From 421896abcaafc5f3b0f7266ed04b48d3bbb920c3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:28:44 -0600 Subject: [PATCH 05/42] unpack --- aiohttp/client.py | 162 +++++++++++++++++----------------------------- 1 file changed, 59 insertions(+), 103 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index e70cfa1ea09..f840e9ed6eb 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -197,6 +197,27 @@ class _RequestOptions(TypedDict, total=False): middlewares: Sequence[ClientMiddlewareType] | None +class _WSConnectOptions(TypedDict, total=False): + method: str + protocols: Collection[str] + timeout: "ClientWSTimeout | _SENTINEL" + receive_timeout: float | None + autoclose: bool + autoping: bool + heartbeat: float | None + auth: BasicAuth | None + origin: str | None + params: Query + headers: LooseHeaders | None + proxy: StrOrURL | None + proxy_auth: BasicAuth | None + ssl: SSLContext | bool | Fingerprint + server_hostname: str | None + proxy_headers: LooseHeaders | None + compress: int + max_msg_size: int + + @frozen_dataclass_decorator class ClientTimeout: total: float | None = None @@ -876,58 +897,25 @@ async def _connect_and_send_request( ) raise - # Overloads for type-safe decode_text parameter - @overload - def ws_connect( - self, - url: StrOrURL, - *, - method: str = ..., - protocols: Collection[str] = ..., - timeout: ClientWSTimeout | _SENTINEL = ..., - receive_timeout: float | None = ..., - autoclose: bool = ..., - autoping: bool = ..., - heartbeat: float | None = ..., - auth: BasicAuth | None = ..., - origin: str | None = ..., - params: Query = ..., - headers: LooseHeaders | None = ..., - proxy: StrOrURL | None = ..., - proxy_auth: BasicAuth | None = ..., - ssl: SSLContext | bool | Fingerprint = ..., - server_hostname: str | None = ..., - proxy_headers: LooseHeaders | None = ..., - compress: int = ..., - max_msg_size: int = ..., - decode_text: Literal[True] = ..., - ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... - - @overload - def ws_connect( - self, - url: StrOrURL, - *, - method: str = ..., - protocols: Collection[str] = ..., - timeout: ClientWSTimeout | _SENTINEL = ..., - receive_timeout: float | None = ..., - autoclose: bool = ..., - autoping: bool = ..., - heartbeat: float | None = ..., - auth: BasicAuth | None = ..., - origin: str | None = ..., - params: Query = ..., - headers: LooseHeaders | None = ..., - proxy: StrOrURL | None = ..., - proxy_auth: BasicAuth | None = ..., - ssl: SSLContext | bool | Fingerprint = ..., - server_hostname: str | None = ..., - proxy_headers: LooseHeaders | None = ..., - compress: int = ..., - max_msg_size: int = ..., - decode_text: Literal[False] = ..., - ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[True] = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... + + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[False], + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... def ws_connect( self, @@ -979,57 +967,25 @@ def ws_connect( ) ) - @overload - async def _ws_connect( - self, - url: StrOrURL, - *, - method: str = ..., - protocols: Collection[str] = ..., - timeout: ClientWSTimeout | _SENTINEL = ..., - receive_timeout: float | None = ..., - autoclose: bool = ..., - autoping: bool = ..., - heartbeat: float | None = ..., - auth: BasicAuth | None = ..., - origin: str | None = ..., - params: Query = ..., - headers: LooseHeaders | None = ..., - proxy: StrOrURL | None = ..., - proxy_auth: BasicAuth | None = ..., - ssl: SSLContext | bool | Fingerprint = ..., - server_hostname: str | None = ..., - proxy_headers: LooseHeaders | None = ..., - compress: int = ..., - max_msg_size: int = ..., - decode_text: Literal[True] = ..., - ) -> "ClientWebSocketResponse[Literal[True]]": ... - - @overload - async def _ws_connect( - self, - url: StrOrURL, - *, - method: str = ..., - protocols: Collection[str] = ..., - timeout: ClientWSTimeout | _SENTINEL = ..., - receive_timeout: float | None = ..., - autoclose: bool = ..., - autoping: bool = ..., - heartbeat: float | None = ..., - auth: BasicAuth | None = ..., - origin: str | None = ..., - params: Query = ..., - headers: LooseHeaders | None = ..., - proxy: StrOrURL | None = ..., - proxy_auth: BasicAuth | None = ..., - ssl: SSLContext | bool | Fingerprint = ..., - server_hostname: str | None = ..., - proxy_headers: LooseHeaders | None = ..., - compress: int = ..., - max_msg_size: int = ..., - decode_text: Literal[False] = ..., - ) -> "ClientWebSocketResponse[Literal[False]]": ... + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[True] = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[Literal[True]]": ... + + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: Literal[False], + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[Literal[False]]": ... async def _ws_connect( self, From e9c5a7add6a2f25b3b4454a64489e85ce5b6a468 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:32:19 -0600 Subject: [PATCH 06/42] fallback overloads --- aiohttp/client.py | 18 ++++++++++++++++++ aiohttp/client_ws.py | 3 +++ 2 files changed, 21 insertions(+) diff --git a/aiohttp/client.py b/aiohttp/client.py index f840e9ed6eb..a4e9ee13de2 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -917,6 +917,15 @@ def ws_connect( **kwargs: Unpack[_WSConnectOptions], ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + @overload + def ws_connect( + self, + url: StrOrURL, + *, + decode_text: bool = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": ... + def ws_connect( self, url: StrOrURL, @@ -987,6 +996,15 @@ async def _ws_connect( **kwargs: Unpack[_WSConnectOptions], ) -> "ClientWebSocketResponse[Literal[False]]": ... + @overload + async def _ws_connect( + self, + url: StrOrURL, + *, + decode_text: bool = ..., + **kwargs: Unpack[_WSConnectOptions], + ) -> "ClientWebSocketResponse[Any]": ... + async def _ws_connect( self, url: StrOrURL, diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 550dc6bda64..16f3f0770d0 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -400,6 +400,9 @@ async def receive_str( self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... + @overload + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. From 10f7c5285a97bac30b40c883c0f3c7457b821575 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:33:08 -0600 Subject: [PATCH 07/42] fallback overloads --- aiohttp/client_ws.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 16f3f0770d0..3487fadee1d 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -439,6 +439,14 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... + @overload + async def receive_json( + self, + *, + loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, From baab16e893261d34ccf6987154c598eb29a70756 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:33:44 -0600 Subject: [PATCH 08/42] fallback overloads --- aiohttp/client_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 3487fadee1d..a0556140c98 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -456,7 +456,7 @@ async def receive_json( timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) - return loads(data) + return loads(data) # type: ignore[arg-type] def __aiter__(self) -> Self: return self From 3c3ad78e4a4c5656e307de35a4843bca194f91e7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:35:05 -0600 Subject: [PATCH 09/42] fallback overloads --- aiohttp/connector.py | 578 ++++++++++++++++++++++++++++--------------- aiohttp/web_ws.py | 15 +- 2 files changed, 390 insertions(+), 203 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 2cdc425d83f..6978ca667e6 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -6,17 +6,33 @@ import traceback import warnings from collections import OrderedDict, defaultdict, deque -from collections.abc import Awaitable, Callable, Iterator, Sequence from contextlib import suppress from http import HTTPStatus from itertools import chain, cycle, islice from time import monotonic from types import TracebackType -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + DefaultDict, + Deque, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) import aiohappyeyeballs from aiohappyeyeballs import AddrInfoType, SocketFactoryType -from multidict import CIMultiDict from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -34,16 +50,12 @@ ssl_errors, ) from .client_proto import ResponseHandler -from .client_reqrep import ( - SSL_ALLOWED_TYPES, - ClientRequest, - ClientRequestBase, - Fingerprint, -) +from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params from .helpers import ( _SENTINEL, ceil_timeout, is_ip_address, + noop, sentinel, set_exception, set_result, @@ -54,15 +66,20 @@ if sys.version_info >= (3, 12): from collections.abc import Buffer else: - Buffer = "bytes | bytearray | memoryview[int] | memoryview[bytes]" + Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] -try: +if TYPE_CHECKING: import ssl SSLContext = ssl.SSLContext -except ImportError: # pragma: no cover - ssl = None # type: ignore[assignment] - SSLContext = object # type: ignore[misc,assignment] +else: + try: + import ssl + + SSLContext = ssl.SSLContext + except ImportError: # pragma: no cover + ssl = None # type: ignore[assignment] + SSLContext = object # type: ignore[misc,assignment] EMPTY_SCHEMA_SET = frozenset({""}) HTTP_SCHEMA_SET = frozenset({"http", "https"}) @@ -96,17 +113,37 @@ from .tracing import Trace +class _DeprecationWaiter: + __slots__ = ("_awaitable", "_awaited") + + def __init__(self, awaitable: Awaitable[Any]) -> None: + self._awaitable = awaitable + self._awaited = False + + def __await__(self) -> Any: + self._awaited = True + return self._awaitable.__await__() + + def __del__(self) -> None: + if not self._awaited: + warnings.warn( + "Connector.close() is a coroutine, " + "please use await connector.close()", + DeprecationWarning, + ) + + +async def _wait_for_close(waiters: List[Awaitable[object]]) -> None: + """Wait for all waiters to finish closing.""" + results = await asyncio.gather(*waiters, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + client_logger.debug("Error while closing connector: %r", res) + + class Connection: - """Represents a single connection.""" - __slots__ = ( - "_key", - "_connector", - "_loop", - "_protocol", - "_callbacks", - "_source_traceback", - ) + _source_traceback = None def __init__( self, @@ -118,20 +155,19 @@ def __init__( self._key = key self._connector = connector self._loop = loop - self._protocol: ResponseHandler | None = protocol - self._callbacks: list[Callable[[], None]] = [] - self._source_traceback = ( - traceback.extract_stack(sys._getframe(1)) if loop.get_debug() else None - ) + self._protocol: Optional[ResponseHandler] = protocol + self._callbacks: List[Callable[[], None]] = [] + + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) def __repr__(self) -> str: return f"Connection<{self._key}>" def __del__(self, _warnings: Any = warnings) -> None: if self._protocol is not None: - _warnings.warn( - f"Unclosed connection {self!r}", ResourceWarning, source=self - ) + kwargs = {"source": self} + _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs) if self._loop.is_closed(): return @@ -147,13 +183,20 @@ def __bool__(self) -> Literal[True]: return True @property - def transport(self) -> asyncio.Transport | None: + def loop(self) -> asyncio.AbstractEventLoop: + warnings.warn( + "connector.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) + return self._loop + + @property + def transport(self) -> Optional[asyncio.Transport]: if self._protocol is None: return None return self._protocol.transport @property - def protocol(self) -> ResponseHandler | None: + def protocol(self) -> Optional[ResponseHandler]: return self._protocol def add_callback(self, callback: Callable[[], None]) -> None: @@ -211,7 +254,7 @@ class _TransportPlaceholder: __slots__ = ("closed", "transport") - def __init__(self, closed_future: asyncio.Future[Exception | None]) -> None: + def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None: """Initialize a placeholder for a transport.""" self.closed = closed_future self.transport = None @@ -249,13 +292,15 @@ class BaseConnector: def __init__( self, *, - keepalive_timeout: _SENTINEL | None | float = sentinel, + keepalive_timeout: Union[object, None, float] = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, ) -> None: + if force_close: if keepalive_timeout is not None and keepalive_timeout is not sentinel: raise ValueError( @@ -265,10 +310,9 @@ def __init__( if keepalive_timeout is sentinel: keepalive_timeout = 15.0 + loop = loop or asyncio.get_running_loop() self._timeout_ceil_threshold = timeout_ceil_threshold - loop = asyncio.get_running_loop() - self._closed = False if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) @@ -276,13 +320,13 @@ def __init__( # Connection pool of reusable connections. # We use a deque to store connections because it has O(1) popleft() # and O(1) append() operations to implement a FIFO queue. - self._conns: defaultdict[ - ConnectionKey, deque[tuple[ResponseHandler, float]] + self._conns: DefaultDict[ + ConnectionKey, Deque[Tuple[ResponseHandler, float]] ] = defaultdict(deque) self._limit = limit self._limit_per_host = limit_per_host - self._acquired: set[ResponseHandler] = set() - self._acquired_per_host: defaultdict[ConnectionKey, set[ResponseHandler]] = ( + self._acquired: Set[ResponseHandler] = set() + self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( defaultdict(set) ) self._keepalive_timeout = cast(float, keepalive_timeout) @@ -291,7 +335,7 @@ def __init__( # {host_key: FIFO list of waiters} # The FIFO is implemented with an OrderedDict with None keys because # python does not have an ordered set. - self._waiters: defaultdict[ + self._waiters: DefaultDict[ ConnectionKey, OrderedDict[asyncio.Future[None], None] ] = defaultdict(OrderedDict) @@ -299,10 +343,10 @@ def __init__( self._factory = functools.partial(ResponseHandler, loop=loop) # start keep-alive connection cleanup task - self._cleanup_handle: asyncio.TimerHandle | None = None + self._cleanup_handle: Optional[asyncio.TimerHandle] = None # start cleanup closed transports task - self._cleanup_closed_handle: asyncio.TimerHandle | None = None + self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED: warnings.warn( @@ -315,9 +359,8 @@ def __init__( enable_cleanup_closed = False self._cleanup_closed_disabled = not enable_cleanup_closed - self._cleanup_closed_transports: list[asyncio.Transport | None] = [] - - self._placeholder_future: asyncio.Future[Exception | None] = ( + self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = [] + self._placeholder_future: asyncio.Future[Optional[Exception]] = ( loop.create_future() ) self._placeholder_future.set_result(None) @@ -331,9 +374,10 @@ def __del__(self, _warnings: Any = warnings) -> None: conns = [repr(c) for c in self._conns.values()] - self._close_immediately() + self._close() - _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, source=self) + kwargs = {"source": self} + _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs) context = { "connector": self, "connections": conns, @@ -343,14 +387,25 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) + def __enter__(self) -> "BaseConnector": + warnings.warn( + '"with Connector():" is deprecated, ' + 'use "async with Connector():" instead', + DeprecationWarning, + ) + return self + + def __exit__(self, *exc: Any) -> None: + self._close() + async def __aenter__(self) -> "BaseConnector": return self async def __aexit__( self, - exc_type: type[BaseException] | None = None, - exc_value: BaseException | None = None, - exc_traceback: TracebackType | None = None, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, ) -> None: await self.close() @@ -392,7 +447,7 @@ def _cleanup(self) -> None: connections = defaultdict(deque) deadline = now - timeout for key, conns in self._conns.items(): - alive: deque[tuple[ResponseHandler, float]] = deque() + alive: Deque[Tuple[ResponseHandler, float]] = deque() for proto, use_time in conns: if proto.is_connected() and use_time - deadline >= 0: alive.append((proto, use_time)) @@ -439,23 +494,28 @@ def _cleanup_closed(self) -> None: timeout_ceil_threshold=self._timeout_ceil_threshold, ) - async def close(self, *, abort_ssl: bool = False) -> None: + def close(self, *, abort_ssl: bool = False) -> Awaitable[None]: """Close all opened transports. :param abort_ssl: If True, SSL connections will be aborted immediately without performing the shutdown handshake. This provides faster cleanup at the cost of less graceful disconnection. """ - waiters = self._close_immediately(abort_ssl=abort_ssl) - if waiters: - results = await asyncio.gather(*waiters, return_exceptions=True) - for res in results: - if isinstance(res, Exception): - err_msg = "Error while closing connector: " + repr(res) - client_logger.debug(err_msg) + if not (waiters := self._close(abort_ssl=abort_ssl)): + # If there are no connections to close, we can return a noop + # awaitable to avoid scheduling a task on the event loop. + return _DeprecationWaiter(noop()) + coro = _wait_for_close(waiters) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to close connections + # immediately to avoid having to schedule the task on the event loop. + task = asyncio.Task(coro, loop=self._loop, eager_start=True) + else: + task = self._loop.create_task(coro) + return _DeprecationWaiter(task) - def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: - waiters: list[Awaitable[object]] = [] + def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: + waiters: List[Awaitable[object]] = [] if self._closed: return waiters @@ -499,7 +559,6 @@ def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[objec if closed := proto.closed: waiters.append(closed) - # TODO (A.Yushovskiy, 24-May-2019) collect transp. closing futures for transport in self._cleanup_closed_transports: if transport is not None: transport.abort() @@ -550,48 +609,20 @@ def _available_connections(self, key: "ConnectionKey") -> int: return total_remain - def _update_proxy_auth_header_and_build_proxy_req( - self, req: ClientRequest - ) -> ClientRequestBase: - """Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests.""" - url = req.proxy - assert url is not None - headers = req.proxy_headers or CIMultiDict[str]() - headers[hdrs.HOST] = req.headers[hdrs.HOST] - proxy_req = ClientRequestBase( - hdrs.METH_GET, - url, - headers=headers, - auth=req.proxy_auth, - loop=self._loop, - ssl=req.ssl, - ) - auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) - if auth is not None: - if not req.is_ssl(): - req.headers[hdrs.PROXY_AUTHORIZATION] = auth - else: - proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth - return proxy_req - async def connect( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Connection: """Get from pool or create new connection.""" key = req.connection_key if (conn := await self._get(key, traces)) is not None: # If we do not have to wait and we can get a connection from the pool # we can avoid the timeout ceil logic and directly return the connection - if req.proxy: - self._update_proxy_auth_header_and_build_proxy_req(req) return conn async with ceil_timeout(timeout.connect, timeout.ceil_threshold): if self._available_connections(key) <= 0: await self._wait_for_available_connection(key, traces) if (conn := await self._get(key, traces)) is not None: - if req.proxy: - self._update_proxy_auth_header_and_build_proxy_req(req) return conn placeholder = cast( @@ -634,7 +665,7 @@ async def connect( return Connection(self, key, proto, self._loop) async def _wait_for_available_connection( - self, key: "ConnectionKey", traces: list["Trace"] + self, key: "ConnectionKey", traces: List["Trace"] ) -> None: """Wait for an available connection slot.""" # We loop here because there is a race between @@ -676,8 +707,8 @@ async def _wait_for_available_connection( attempts += 1 async def _get( - self, key: "ConnectionKey", traces: list["Trace"] - ) -> Connection | None: + self, key: "ConnectionKey", traces: List["Trace"] + ) -> Optional[Connection]: """Get next reusable connection for the key or None. The connection will be marked as acquired. @@ -772,6 +803,7 @@ def _release( if self._force_close or should_close or protocol.should_close: transport = protocol.transport protocol.close() + if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) return @@ -788,27 +820,27 @@ def _release( ) async def _create_connection( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: raise NotImplementedError() class _DNSCacheTable: - def __init__(self, ttl: float | None = None) -> None: - self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {} - self._timestamps: dict[tuple[str, int], float] = {} + def __init__(self, ttl: Optional[float] = None) -> None: + self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} + self._timestamps: Dict[Tuple[str, int], float] = {} self._ttl = ttl def __contains__(self, host: object) -> bool: return host in self._addrs_rr - def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None: + def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: self._timestamps[key] = monotonic() - def remove(self, key: tuple[str, int]) -> None: + def remove(self, key: Tuple[str, int]) -> None: self._addrs_rr.pop(key, None) if self._ttl is not None: @@ -818,14 +850,14 @@ def clear(self) -> None: self._addrs_rr.clear() self._timestamps.clear() - def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]: + def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: loop, length = self._addrs_rr[key] addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` next(loop) return addrs - def expired(self, key: tuple[str, int]) -> bool: + def expired(self, key: Tuple[str, int]) -> bool: if self._ttl is None: return False @@ -840,7 +872,7 @@ def _make_ssl_context(verified: bool) -> SSLContext: """ if ssl is None: # No ssl support - return None # type: ignore[unreachable] + return None if verified: sslcontext = ssl.create_default_context() else: @@ -907,22 +939,26 @@ class TCPConnector(BaseConnector): def __init__( self, *, + verify_ssl: bool = True, + fingerprint: Optional[bytes] = None, use_dns_cache: bool = True, - ttl_dns_cache: int | None = 10, + ttl_dns_cache: Optional[int] = 10, family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, - ssl: bool | Fingerprint | SSLContext = True, - local_addr: tuple[str, int] | None = None, - resolver: AbstractResolver | None = None, - keepalive_timeout: None | float | _SENTINEL = sentinel, + ssl_context: Optional[SSLContext] = None, + ssl: Union[bool, Fingerprint, SSLContext] = True, + local_addr: Optional[Tuple[str, int]] = None, + resolver: Optional[AbstractResolver] = None, + keepalive_timeout: Union[None, float, object] = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, - happy_eyeballs_delay: float | None = 0.25, - interleave: int | None = None, - socket_factory: SocketFactoryType | None = None, - ssl_shutdown_timeout: _SENTINEL | None | float = sentinel, + happy_eyeballs_delay: Optional[float] = 0.25, + interleave: Optional[int] = None, + socket_factory: Optional[SocketFactoryType] = None, + ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -930,19 +966,15 @@ def __init__( limit=limit, limit_per_host=limit_per_host, enable_cleanup_closed=enable_cleanup_closed, + loop=loop, timeout_ceil_threshold=timeout_ceil_threshold, ) - if not isinstance(ssl, SSL_ALLOWED_TYPES): - raise TypeError( - "ssl should be SSLContext, Fingerprint, or bool, " - f"got {ssl!r} instead." - ) - self._ssl = ssl + self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) self._resolver: AbstractResolver if resolver is None: - self._resolver = DefaultResolver() + self._resolver = DefaultResolver(loop=self._loop) self._resolver_owner = True else: self._resolver = resolver @@ -950,17 +982,16 @@ def __init__( self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) - self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = ( - {} - ) + self._throttle_dns_futures: Dict[ + Tuple[str, int], Set["asyncio.Future[None]"] + ] = {} self._family = family self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave - self._resolve_host_tasks: set[asyncio.Task[list[ResolveResult]]] = set() + self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() self._socket_factory = socket_factory - self._ssl_shutdown_timeout: float | None - + self._ssl_shutdown_timeout: Optional[float] # Handle ssl_shutdown_timeout with warning for Python < 3.11 if ssl_shutdown_timeout is sentinel: self._ssl_shutdown_timeout = 0 @@ -984,8 +1015,22 @@ def __init__( ) self._ssl_shutdown_timeout = ssl_shutdown_timeout + def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: + """Close all ongoing DNS calls.""" + for fut in chain.from_iterable(self._throttle_dns_futures.values()): + fut.cancel() + + waiters = super()._close(abort_ssl=abort_ssl) + + for t in self._resolve_host_tasks: + t.cancel() + waiters.append(t) + + return waiters + async def close(self, *, abort_ssl: bool = False) -> None: - """Close all opened transports. + """ + Close all opened transports. :param abort_ssl: If True, SSL connections will be aborted immediately without performing the shutdown handshake. If False (default), @@ -998,18 +1043,6 @@ async def close(self, *, abort_ssl: bool = False) -> None: # Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0) - def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: - for fut in chain.from_iterable(self._throttle_dns_futures.values()): - fut.cancel() - - waiters = super()._close_immediately(abort_ssl=abort_ssl) - - for t in self._resolve_host_tasks: - t.cancel() - waiters.append(t) - - return waiters - @property def family(self) -> int: """Socket family like AF_INET.""" @@ -1020,7 +1053,9 @@ def use_dns_cache(self) -> bool: """True if local DNS caching is enabled.""" return self._use_dns_cache - def clear_dns_cache(self, host: str | None = None, port: int | None = None) -> None: + def clear_dns_cache( + self, host: Optional[str] = None, port: Optional[int] = None + ) -> None: """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: self._cached_hosts.remove((host, port)) @@ -1030,8 +1065,8 @@ def clear_dns_cache(self, host: str | None = None, port: int | None = None) -> N self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Sequence["Trace"] | None = None - ) -> list[ResolveResult]: + self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None + ) -> List[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ @@ -1046,6 +1081,7 @@ async def _resolve_host( ] if not self._use_dns_cache: + if traces: for trace in traces: await trace.send_dns_resolvehost_start(host) @@ -1068,7 +1104,7 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result - futures: set[asyncio.Future[None]] + futures: Set["asyncio.Future[None]"] # # If multiple connectors are resolving the same host, we wait # for the first one to resolve and then use the result for all of them. @@ -1112,7 +1148,7 @@ async def _resolve_host( return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: - def drop_exception(fut: "asyncio.Future[list[ResolveResult]]") -> None: + def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() @@ -1121,12 +1157,12 @@ def drop_exception(fut: "asyncio.Future[list[ResolveResult]]") -> None: async def _resolve_host_with_throttle( self, - key: tuple[str, int], + key: Tuple[str, int], host: str, port: int, - futures: set[asyncio.Future[None]], - traces: Sequence["Trace"] | None, - ) -> list[ResolveResult]: + futures: Set["asyncio.Future[None]"], + traces: Optional[Sequence["Trace"]], + ) -> List[ResolveResult]: """Resolve host and set result for all waiters. This method must be run in a task and shielded from cancellation @@ -1161,7 +1197,7 @@ async def _resolve_host_with_throttle( return self._cached_hosts.next_addrs(key) async def _create_connection( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. @@ -1174,7 +1210,7 @@ async def _create_connection( return proto - def _get_ssl_context(self, req: ClientRequestBase) -> SSLContext | None: + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1207,7 +1243,7 @@ def _get_ssl_context(self, req: ClientRequestBase) -> SSLContext | None: return _SSL_CONTEXT_UNVERIFIED return _SSL_CONTEXT_VERIFIED - def _get_fingerprint(self, req: ClientRequestBase) -> "Fingerprint | None": + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -1219,32 +1255,71 @@ def _get_fingerprint(self, req: ClientRequestBase) -> "Fingerprint | None": async def _wrap_create_connection( self, *args: Any, - addr_infos: list[AddrInfoType], - req: ClientRequestBase, + addr_infos: List[AddrInfoType], + req: ClientRequest, timeout: "ClientTimeout", - client_error: type[Exception] = ClientConnectorError, + client_error: Type[Exception] = ClientConnectorError, **kwargs: Any, - ) -> tuple[asyncio.Transport, ResponseHandler]: + ) -> Tuple[asyncio.Transport, ResponseHandler]: + # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used + if ( + kwargs.get("ssl") + and self._ssl_shutdown_timeout + and sys.version_info >= (3, 11) + ): + kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): - sock = await aiohappyeyeballs.start_connection( - addr_infos=addr_infos, - local_addr_infos=self._local_addr_infos, - happy_eyeballs_delay=self._happy_eyeballs_delay, - interleave=self._interleave, - loop=self._loop, - socket_factory=self._socket_factory, - ) - # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used - if ( - kwargs.get("ssl") - and self._ssl_shutdown_timeout - and sys.version_info >= (3, 11) - ): - kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout - return await self._loop.create_connection(*args, **kwargs, sock=sock) + if self._happy_eyeballs_delay is None: + # If happyeyeballs is disabled, connect in sequence + # this avoids a bug in uvloop where it can lose track + # of sockets passed between aiohappyeyeballs.start_connect + # and create_connection and try to reuse the same fd. + # https://github.com/aio-libs/aiohttp/issues/10506 + # https://github.com/MagicStack/uvloop/issues/645 + first_addr_infos = addr_infos[0] + address_tuple = first_addr_infos[4] + host: str = address_tuple[0] + port: int = address_tuple[1] + return await self._loop.create_connection( + *args, host=host, port=port, **kwargs + ) + else: + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=self._local_addr_infos, + happy_eyeballs_delay=self._happy_eyeballs_delay, + interleave=self._interleave, + loop=self._loop, + socket_factory=self._socket_factory, + ) + return await self._loop.create_connection( + *args, **kwargs, sock=sock + ) + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise client_error(req.connection_key, exc) from exc + + async def _wrap_existing_connection( + self, + *args: Any, + req: ClientRequest, + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + **kwargs: Any, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + return await self._loop.create_connection(*args, **kwargs) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1254,13 +1329,56 @@ async def _wrap_create_connection( raise raise client_error(req.connection_key, exc) from exc + def _fail_on_no_start_tls(self, req: "ClientRequest") -> None: + """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``. + + It is necessary for TLS-in-TLS so that it is possible to + send HTTPS queries through HTTPS proxies. + + This doesn't affect regular HTTP requests, though. + """ + if not req.is_ssl(): + return + + proxy_url = req.proxy + assert proxy_url is not None + if proxy_url.scheme != "https": + return + + self._check_loop_for_start_tls() + + def _check_loop_for_start_tls(self) -> None: + try: + self._loop.start_tls + except AttributeError as attr_exc: + raise RuntimeError( + "An HTTPS request is being sent through an HTTPS proxy. " + "This needs support for TLS in TLS but it is not implemented " + "in your runtime for the stdlib asyncio.\n\n" + "Please upgrade to Python 3.11 or higher. For more details, " + "please see:\n" + "* https://bugs.python.org/issue37179\n" + "* https://github.com/python/cpython/pull/28073\n" + "* https://docs.aiohttp.org/en/stable/" + "client_advanced.html#proxy-support\n" + "* https://github.com/aio-libs/aiohttp/discussions/6044\n", + ) from attr_exc + + def _loop_supports_start_tls(self) -> bool: + try: + self._check_loop_for_start_tls() + except RuntimeError: + return False + else: + return True + def _warn_about_tls_in_tls( self, underlying_transport: asyncio.Transport, req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" - if req.url.scheme != "https": + if req.request_info.url.scheme != "https": return # Check if uvloop is being used, which supports TLS in TLS, @@ -1281,7 +1399,7 @@ def _warn_about_tls_in_tls( warnings.warn( "An HTTPS request is being sent through an HTTPS proxy. " "This support for TLS in TLS is known to be disabled " - "in the stdlib asyncio. This is why you'll probably see " + "in the stdlib asyncio (Python <3.11). This is why you'll probably see " "an error in the log below.\n\n" "It is possible to enable it via monkeypatching. " "For more details, see:\n" @@ -1302,8 +1420,8 @@ async def _start_tls_connection( underlying_transport: asyncio.Transport, req: ClientRequest, timeout: "ClientTimeout", - client_error: type[Exception] = ClientConnectorError, - ) -> tuple[asyncio.BaseTransport, ResponseHandler]: + client_error: Type[Exception] = ClientConnectorError, + ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS sslcontext = self._get_ssl_context(req) @@ -1323,7 +1441,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.url.raw_host, + server_hostname=req.server_hostname or req.host, ssl_handshake_timeout=timeout.total, ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) @@ -1332,7 +1450,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.url.raw_host, + server_hostname=req.server_hostname or req.host, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1369,7 +1487,7 @@ async def _start_tls_connection( raise ClientConnectionError( "Cannot initialize a TLS-in-TLS connection to host " - f"{req.url.host!s}:{req.url.port:d} through an underlying connection " + f"{req.host!s}:{req.port:d} through an underlying connection " f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " f"[{type_err!s}]" ) from type_err @@ -1384,14 +1502,14 @@ async def _start_tls_connection( return tls_transport, tls_proto def _convert_hosts_to_addr_infos( - self, hosts: list[ResolveResult] - ) -> list[AddrInfoType]: + self, hosts: List[ResolveResult] + ) -> List[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. """ - addr_infos: list[AddrInfoType] = [] + addr_infos: List[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host @@ -1406,12 +1524,12 @@ def _convert_hosts_to_addr_infos( async def _create_direct_connection( self, - req: ClientRequestBase, - traces: list["Trace"], + req: ClientRequest, + traces: List["Trace"], timeout: "ClientTimeout", *, - client_error: type[Exception] = ClientConnectorError, - ) -> tuple[asyncio.Transport, ResponseHandler]: + client_error: Type[Exception] = ClientConnectorError, + ) -> Tuple[asyncio.Transport, ResponseHandler]: sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) @@ -1422,7 +1540,7 @@ async def _create_direct_connection( # See https://github.com/aio-libs/aiohttp/pull/7364. if host.endswith(".."): host = host.rstrip(".") + "." - port = req.url.port + port = req.port assert port is not None try: # Cancelling this lookup should not cancel the underlying lookup @@ -1436,7 +1554,7 @@ async def _create_direct_connection( # it is problem of resolving proxy ip itself raise ClientConnectorDNSError(req.connection_key, exc) from exc - last_exc: Exception | None = None + last_exc: Optional[Exception] = None addr_infos = self._convert_hosts_to_addr_infos(hosts) while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. @@ -1457,7 +1575,12 @@ async def _create_direct_connection( ) except (ClientConnectorError, asyncio.TimeoutError) as exc: last_exc = exc - aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) + if self._happy_eyeballs_delay is None: + addr_infos.pop(0) + else: + aiohappyeyeballs.pop_addr_infos_interleave( + addr_infos, self._interleave + ) continue if req.is_ssl() and fingerprint: @@ -1475,21 +1598,47 @@ async def _create_direct_connection( continue return transp, proto - assert last_exc is not None - raise last_exc + else: + assert last_exc is not None + raise last_exc async def _create_proxy_connection( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" - ) -> tuple[asyncio.BaseTransport, ResponseHandler]: - proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req) + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: + self._fail_on_no_start_tls(req) + runtime_has_start_tls = self._loop_supports_start_tls() + + headers: Dict[str, str] = {} + if req.proxy_headers is not None: + headers = req.proxy_headers # type: ignore[assignment] + headers[hdrs.HOST] = req.headers[hdrs.HOST] + + url = req.proxy + assert url is not None + proxy_req = ClientRequest( + hdrs.METH_GET, + url, + headers=headers, + auth=req.proxy_auth, + loop=self._loop, + ssl=req.ssl, + ) # create connection to proxy server transport, proto = await self._create_direct_connection( proxy_req, [], timeout, client_error=ClientProxyConnectionError ) + auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) + if auth is not None: + if not req.is_ssl(): + req.headers[hdrs.PROXY_AUTHORIZATION] = auth + else: + proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth + if req.is_ssl(): - self._warn_about_tls_in_tls(transport, req) + if runtime_has_start_tls: + self._warn_about_tls_in_tls(transport, req) # For HTTPS requests over HTTP proxy # we must notify proxy to tunnel connection @@ -1506,7 +1655,7 @@ async def _create_proxy_connection( proxy=None, proxy_auth=None, proxy_headers_hash=None ) conn = _ConnectTunnelConnection(self, key, proto, self._loop) - proxy_resp = await proxy_req._send(conn) + proxy_resp = await proxy_req.send(conn) try: protocol = conn._protocol assert protocol is not None @@ -1515,7 +1664,7 @@ async def _create_proxy_connection( # once the response is received and processed allowing # START_TLS to work on the connection below. protocol.set_response_params( - read_until_eof=True, + read_until_eof=runtime_has_start_tls, timeout_ceil_threshold=self._timeout_ceil_threshold, ) resp = await proxy_resp.start(conn) @@ -1537,12 +1686,35 @@ async def _create_proxy_connection( message=message, headers=resp.headers, ) + if not runtime_has_start_tls: + rawsock = transport.get_extra_info("socket", default=None) + if rawsock is None: + raise RuntimeError( + "Transport does not expose socket instance" + ) + # Duplicate the socket, so now we can close proxy transport + rawsock = rawsock.dup() except BaseException: # It shouldn't be closed in `finally` because it's fed to # `loop.start_tls()` and the docs say not to touch it after # passing there. transport.close() raise + finally: + if not runtime_has_start_tls: + transport.close() + + if not runtime_has_start_tls: + # HTTP proxy with support for upgrade to HTTPS + sslcontext = self._get_ssl_context(req) + return await self._wrap_existing_connection( + self._factory, + timeout=timeout, + ssl=sslcontext, + sock=rawsock, + server_hostname=req.host, + req=req, + ) return await self._start_tls_connection( # Access the old transport for the last time before it's @@ -1575,15 +1747,17 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: _SENTINEL | float | None = sentinel, + keepalive_timeout: Union[object, float, None] = sentinel, limit: int = 100, limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, + loop=loop, ) self._path = path @@ -1593,7 +1767,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( @@ -1631,15 +1805,17 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: _SENTINEL | float | None = sentinel, + keepalive_timeout: Union[object, float, None] = sentinel, limit: int = 100, limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, + loop=loop, ) if not isinstance( self._loop, @@ -1656,7 +1832,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 959f71cc19b..786149081b7 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -609,6 +609,9 @@ async def receive_str( self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... + @overload + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -645,6 +648,14 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... + @overload + async def receive_json( + self, + *, + loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, @@ -653,8 +664,8 @@ async def receive_json( ) = json.loads, timeout: float | None = None, ) -> Any: - data: str | bytes = await self.receive_str(timeout=timeout) - return loads(data) + data = await self.receive_str(timeout=timeout) + return loads(data) # type: ignore[arg-type] async def write( self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] From ca14d5f24c36b3feebfc72afffb6a343b7e49780 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:35:17 -0600 Subject: [PATCH 10/42] fallback overloads --- aiohttp/connector.py | 533 ++++++++++++++----------------------------- 1 file changed, 174 insertions(+), 359 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 6978ca667e6..b1820358bae 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -6,33 +6,17 @@ import traceback import warnings from collections import OrderedDict, defaultdict, deque +from collections.abc import Awaitable, Callable, Iterator, Sequence from contextlib import suppress from http import HTTPStatus from itertools import chain, cycle, islice from time import monotonic from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - DefaultDict, - Deque, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Literal, cast import aiohappyeyeballs from aiohappyeyeballs import AddrInfoType, SocketFactoryType +from multidict import CIMultiDict from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -50,12 +34,16 @@ ssl_errors, ) from .client_proto import ResponseHandler -from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params +from .client_reqrep import ( + SSL_ALLOWED_TYPES, + ClientRequest, + ClientRequestBase, + Fingerprint, +) from .helpers import ( _SENTINEL, ceil_timeout, is_ip_address, - noop, sentinel, set_exception, set_result, @@ -66,20 +54,15 @@ if sys.version_info >= (3, 12): from collections.abc import Buffer else: - Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + Buffer = "bytes | bytearray | memoryview[int] | memoryview[bytes]" -if TYPE_CHECKING: +try: import ssl SSLContext = ssl.SSLContext -else: - try: - import ssl - - SSLContext = ssl.SSLContext - except ImportError: # pragma: no cover - ssl = None # type: ignore[assignment] - SSLContext = object # type: ignore[misc,assignment] +except ImportError: # pragma: no cover + ssl = None # type: ignore[assignment] + SSLContext = object # type: ignore[misc,assignment] EMPTY_SCHEMA_SET = frozenset({""}) HTTP_SCHEMA_SET = frozenset({"http", "https"}) @@ -113,37 +96,17 @@ from .tracing import Trace -class _DeprecationWaiter: - __slots__ = ("_awaitable", "_awaited") - - def __init__(self, awaitable: Awaitable[Any]) -> None: - self._awaitable = awaitable - self._awaited = False - - def __await__(self) -> Any: - self._awaited = True - return self._awaitable.__await__() - - def __del__(self) -> None: - if not self._awaited: - warnings.warn( - "Connector.close() is a coroutine, " - "please use await connector.close()", - DeprecationWarning, - ) - - -async def _wait_for_close(waiters: List[Awaitable[object]]) -> None: - """Wait for all waiters to finish closing.""" - results = await asyncio.gather(*waiters, return_exceptions=True) - for res in results: - if isinstance(res, Exception): - client_logger.debug("Error while closing connector: %r", res) - - class Connection: + """Represents a single connection.""" - _source_traceback = None + __slots__ = ( + "_key", + "_connector", + "_loop", + "_protocol", + "_callbacks", + "_source_traceback", + ) def __init__( self, @@ -155,19 +118,20 @@ def __init__( self._key = key self._connector = connector self._loop = loop - self._protocol: Optional[ResponseHandler] = protocol - self._callbacks: List[Callable[[], None]] = [] - - if loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._protocol: ResponseHandler | None = protocol + self._callbacks: list[Callable[[], None]] = [] + self._source_traceback = ( + traceback.extract_stack(sys._getframe(1)) if loop.get_debug() else None + ) def __repr__(self) -> str: return f"Connection<{self._key}>" def __del__(self, _warnings: Any = warnings) -> None: if self._protocol is not None: - kwargs = {"source": self} - _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs) + _warnings.warn( + f"Unclosed connection {self!r}", ResourceWarning, source=self + ) if self._loop.is_closed(): return @@ -183,20 +147,13 @@ def __bool__(self) -> Literal[True]: return True @property - def loop(self) -> asyncio.AbstractEventLoop: - warnings.warn( - "connector.loop property is deprecated", DeprecationWarning, stacklevel=2 - ) - return self._loop - - @property - def transport(self) -> Optional[asyncio.Transport]: + def transport(self) -> asyncio.Transport | None: if self._protocol is None: return None return self._protocol.transport @property - def protocol(self) -> Optional[ResponseHandler]: + def protocol(self) -> ResponseHandler | None: return self._protocol def add_callback(self, callback: Callable[[], None]) -> None: @@ -254,7 +211,7 @@ class _TransportPlaceholder: __slots__ = ("closed", "transport") - def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None: + def __init__(self, closed_future: asyncio.Future[Exception | None]) -> None: """Initialize a placeholder for a transport.""" self.closed = closed_future self.transport = None @@ -292,15 +249,13 @@ class BaseConnector: def __init__( self, *, - keepalive_timeout: Union[object, None, float] = sentinel, + keepalive_timeout: _SENTINEL | None | float = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, ) -> None: - if force_close: if keepalive_timeout is not None and keepalive_timeout is not sentinel: raise ValueError( @@ -310,9 +265,10 @@ def __init__( if keepalive_timeout is sentinel: keepalive_timeout = 15.0 - loop = loop or asyncio.get_running_loop() self._timeout_ceil_threshold = timeout_ceil_threshold + loop = asyncio.get_running_loop() + self._closed = False if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) @@ -320,13 +276,13 @@ def __init__( # Connection pool of reusable connections. # We use a deque to store connections because it has O(1) popleft() # and O(1) append() operations to implement a FIFO queue. - self._conns: DefaultDict[ - ConnectionKey, Deque[Tuple[ResponseHandler, float]] + self._conns: defaultdict[ + ConnectionKey, deque[tuple[ResponseHandler, float]] ] = defaultdict(deque) self._limit = limit self._limit_per_host = limit_per_host - self._acquired: Set[ResponseHandler] = set() - self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( + self._acquired: set[ResponseHandler] = set() + self._acquired_per_host: defaultdict[ConnectionKey, set[ResponseHandler]] = ( defaultdict(set) ) self._keepalive_timeout = cast(float, keepalive_timeout) @@ -335,7 +291,7 @@ def __init__( # {host_key: FIFO list of waiters} # The FIFO is implemented with an OrderedDict with None keys because # python does not have an ordered set. - self._waiters: DefaultDict[ + self._waiters: defaultdict[ ConnectionKey, OrderedDict[asyncio.Future[None], None] ] = defaultdict(OrderedDict) @@ -343,10 +299,10 @@ def __init__( self._factory = functools.partial(ResponseHandler, loop=loop) # start keep-alive connection cleanup task - self._cleanup_handle: Optional[asyncio.TimerHandle] = None + self._cleanup_handle: asyncio.TimerHandle | None = None # start cleanup closed transports task - self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None + self._cleanup_closed_handle: asyncio.TimerHandle | None = None if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED: warnings.warn( @@ -359,8 +315,9 @@ def __init__( enable_cleanup_closed = False self._cleanup_closed_disabled = not enable_cleanup_closed - self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = [] - self._placeholder_future: asyncio.Future[Optional[Exception]] = ( + self._cleanup_closed_transports: list[asyncio.Transport | None] = [] + + self._placeholder_future: asyncio.Future[Exception | None] = ( loop.create_future() ) self._placeholder_future.set_result(None) @@ -374,10 +331,9 @@ def __del__(self, _warnings: Any = warnings) -> None: conns = [repr(c) for c in self._conns.values()] - self._close() + self._close_immediately() - kwargs = {"source": self} - _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs) + _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, source=self) context = { "connector": self, "connections": conns, @@ -387,25 +343,14 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - def __enter__(self) -> "BaseConnector": - warnings.warn( - '"with Connector():" is deprecated, ' - 'use "async with Connector():" instead', - DeprecationWarning, - ) - return self - - def __exit__(self, *exc: Any) -> None: - self._close() - async def __aenter__(self) -> "BaseConnector": return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + exc_traceback: TracebackType | None = None, ) -> None: await self.close() @@ -447,7 +392,7 @@ def _cleanup(self) -> None: connections = defaultdict(deque) deadline = now - timeout for key, conns in self._conns.items(): - alive: Deque[Tuple[ResponseHandler, float]] = deque() + alive: deque[tuple[ResponseHandler, float]] = deque() for proto, use_time in conns: if proto.is_connected() and use_time - deadline >= 0: alive.append((proto, use_time)) @@ -494,28 +439,23 @@ def _cleanup_closed(self) -> None: timeout_ceil_threshold=self._timeout_ceil_threshold, ) - def close(self, *, abort_ssl: bool = False) -> Awaitable[None]: + async def close(self, *, abort_ssl: bool = False) -> None: """Close all opened transports. :param abort_ssl: If True, SSL connections will be aborted immediately without performing the shutdown handshake. This provides faster cleanup at the cost of less graceful disconnection. """ - if not (waiters := self._close(abort_ssl=abort_ssl)): - # If there are no connections to close, we can return a noop - # awaitable to avoid scheduling a task on the event loop. - return _DeprecationWaiter(noop()) - coro = _wait_for_close(waiters) - if sys.version_info >= (3, 12): - # Optimization for Python 3.12, try to close connections - # immediately to avoid having to schedule the task on the event loop. - task = asyncio.Task(coro, loop=self._loop, eager_start=True) - else: - task = self._loop.create_task(coro) - return _DeprecationWaiter(task) + waiters = self._close_immediately(abort_ssl=abort_ssl) + if waiters: + results = await asyncio.gather(*waiters, return_exceptions=True) + for res in results: + if isinstance(res, Exception): + err_msg = "Error while closing connector: " + repr(res) + client_logger.debug(err_msg) - def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: - waiters: List[Awaitable[object]] = [] + def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: + waiters: list[Awaitable[object]] = [] if self._closed: return waiters @@ -559,6 +499,7 @@ def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: if closed := proto.closed: waiters.append(closed) + # TODO (A.Yushovskiy, 24-May-2019) collect transp. closing futures for transport in self._cleanup_closed_transports: if transport is not None: transport.abort() @@ -610,7 +551,7 @@ def _available_connections(self, key: "ConnectionKey") -> int: return total_remain async def connect( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> Connection: """Get from pool or create new connection.""" key = req.connection_key @@ -665,7 +606,7 @@ async def connect( return Connection(self, key, proto, self._loop) async def _wait_for_available_connection( - self, key: "ConnectionKey", traces: List["Trace"] + self, key: "ConnectionKey", traces: list["Trace"] ) -> None: """Wait for an available connection slot.""" # We loop here because there is a race between @@ -707,8 +648,8 @@ async def _wait_for_available_connection( attempts += 1 async def _get( - self, key: "ConnectionKey", traces: List["Trace"] - ) -> Optional[Connection]: + self, key: "ConnectionKey", traces: list["Trace"] + ) -> Connection | None: """Get next reusable connection for the key or None. The connection will be marked as acquired. @@ -803,7 +744,6 @@ def _release( if self._force_close or should_close or protocol.should_close: transport = protocol.transport protocol.close() - if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) return @@ -820,27 +760,27 @@ def _release( ) async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: raise NotImplementedError() class _DNSCacheTable: - def __init__(self, ttl: Optional[float] = None) -> None: - self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} - self._timestamps: Dict[Tuple[str, int], float] = {} + def __init__(self, ttl: float | None = None) -> None: + self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {} + self._timestamps: dict[tuple[str, int], float] = {} self._ttl = ttl def __contains__(self, host: object) -> bool: return host in self._addrs_rr - def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: + def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None: self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: self._timestamps[key] = monotonic() - def remove(self, key: Tuple[str, int]) -> None: + def remove(self, key: tuple[str, int]) -> None: self._addrs_rr.pop(key, None) if self._ttl is not None: @@ -850,14 +790,14 @@ def clear(self) -> None: self._addrs_rr.clear() self._timestamps.clear() - def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: + def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]: loop, length = self._addrs_rr[key] addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` next(loop) return addrs - def expired(self, key: Tuple[str, int]) -> bool: + def expired(self, key: tuple[str, int]) -> bool: if self._ttl is None: return False @@ -872,7 +812,7 @@ def _make_ssl_context(verified: bool) -> SSLContext: """ if ssl is None: # No ssl support - return None + return None # type: ignore[unreachable] if verified: sslcontext = ssl.create_default_context() else: @@ -939,26 +879,22 @@ class TCPConnector(BaseConnector): def __init__( self, *, - verify_ssl: bool = True, - fingerprint: Optional[bytes] = None, use_dns_cache: bool = True, - ttl_dns_cache: Optional[int] = 10, + ttl_dns_cache: int | None = 10, family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, - ssl_context: Optional[SSLContext] = None, - ssl: Union[bool, Fingerprint, SSLContext] = True, - local_addr: Optional[Tuple[str, int]] = None, - resolver: Optional[AbstractResolver] = None, - keepalive_timeout: Union[None, float, object] = sentinel, + ssl: bool | Fingerprint | SSLContext = True, + local_addr: tuple[str, int] | None = None, + resolver: AbstractResolver | None = None, + keepalive_timeout: None | float | _SENTINEL = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, timeout_ceil_threshold: float = 5, - happy_eyeballs_delay: Optional[float] = 0.25, - interleave: Optional[int] = None, - socket_factory: Optional[SocketFactoryType] = None, - ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, + happy_eyeballs_delay: float | None = 0.25, + interleave: int | None = None, + socket_factory: SocketFactoryType | None = None, + ssl_shutdown_timeout: _SENTINEL | None | float = sentinel, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -966,15 +902,19 @@ def __init__( limit=limit, limit_per_host=limit_per_host, enable_cleanup_closed=enable_cleanup_closed, - loop=loop, timeout_ceil_threshold=timeout_ceil_threshold, ) - self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) + if not isinstance(ssl, SSL_ALLOWED_TYPES): + raise TypeError( + "ssl should be SSLContext, Fingerprint, or bool, " + f"got {ssl!r} instead." + ) + self._ssl = ssl self._resolver: AbstractResolver if resolver is None: - self._resolver = DefaultResolver(loop=self._loop) + self._resolver = DefaultResolver() self._resolver_owner = True else: self._resolver = resolver @@ -982,16 +922,17 @@ def __init__( self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) - self._throttle_dns_futures: Dict[ - Tuple[str, int], Set["asyncio.Future[None]"] - ] = {} + self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = ( + {} + ) self._family = family self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave - self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() + self._resolve_host_tasks: set[asyncio.Task[list[ResolveResult]]] = set() self._socket_factory = socket_factory - self._ssl_shutdown_timeout: Optional[float] + self._ssl_shutdown_timeout: float | None + # Handle ssl_shutdown_timeout with warning for Python < 3.11 if ssl_shutdown_timeout is sentinel: self._ssl_shutdown_timeout = 0 @@ -1015,22 +956,8 @@ def __init__( ) self._ssl_shutdown_timeout = ssl_shutdown_timeout - def _close(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: - """Close all ongoing DNS calls.""" - for fut in chain.from_iterable(self._throttle_dns_futures.values()): - fut.cancel() - - waiters = super()._close(abort_ssl=abort_ssl) - - for t in self._resolve_host_tasks: - t.cancel() - waiters.append(t) - - return waiters - async def close(self, *, abort_ssl: bool = False) -> None: - """ - Close all opened transports. + """Close all opened transports. :param abort_ssl: If True, SSL connections will be aborted immediately without performing the shutdown handshake. If False (default), @@ -1043,6 +970,18 @@ async def close(self, *, abort_ssl: bool = False) -> None: # Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0) + def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: + for fut in chain.from_iterable(self._throttle_dns_futures.values()): + fut.cancel() + + waiters = super()._close_immediately(abort_ssl=abort_ssl) + + for t in self._resolve_host_tasks: + t.cancel() + waiters.append(t) + + return waiters + @property def family(self) -> int: """Socket family like AF_INET.""" @@ -1053,9 +992,7 @@ def use_dns_cache(self) -> bool: """True if local DNS caching is enabled.""" return self._use_dns_cache - def clear_dns_cache( - self, host: Optional[str] = None, port: Optional[int] = None - ) -> None: + def clear_dns_cache(self, host: str | None = None, port: int | None = None) -> None: """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: self._cached_hosts.remove((host, port)) @@ -1065,8 +1002,8 @@ def clear_dns_cache( self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None - ) -> List[ResolveResult]: + self, host: str, port: int, traces: Sequence["Trace"] | None = None + ) -> list[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ @@ -1081,7 +1018,6 @@ async def _resolve_host( ] if not self._use_dns_cache: - if traces: for trace in traces: await trace.send_dns_resolvehost_start(host) @@ -1104,7 +1040,7 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result - futures: Set["asyncio.Future[None]"] + futures: set[asyncio.Future[None]] # # If multiple connectors are resolving the same host, we wait # for the first one to resolve and then use the result for all of them. @@ -1148,7 +1084,7 @@ async def _resolve_host( return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: - def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: + def drop_exception(fut: "asyncio.Future[list[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() @@ -1157,12 +1093,12 @@ def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: async def _resolve_host_with_throttle( self, - key: Tuple[str, int], + key: tuple[str, int], host: str, port: int, - futures: Set["asyncio.Future[None]"], - traces: Optional[Sequence["Trace"]], - ) -> List[ResolveResult]: + futures: set[asyncio.Future[None]], + traces: Sequence["Trace"] | None, + ) -> list[ResolveResult]: """Resolve host and set result for all waiters. This method must be run in a task and shielded from cancellation @@ -1197,7 +1133,7 @@ async def _resolve_host_with_throttle( return self._cached_hosts.next_addrs(key) async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. @@ -1210,7 +1146,7 @@ async def _create_connection( return proto - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequestBase) -> SSLContext | None: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1243,7 +1179,7 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: return _SSL_CONTEXT_UNVERIFIED return _SSL_CONTEXT_VERIFIED - def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: + def _get_fingerprint(self, req: ClientRequestBase) -> "Fingerprint | None": ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -1255,71 +1191,32 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - addr_infos: List[AddrInfoType], - req: ClientRequest, - timeout: "ClientTimeout", - client_error: Type[Exception] = ClientConnectorError, - **kwargs: Any, - ) -> Tuple[asyncio.Transport, ResponseHandler]: - # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used - if ( - kwargs.get("ssl") - and self._ssl_shutdown_timeout - and sys.version_info >= (3, 11) - ): - kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout - try: - async with ceil_timeout( - timeout.sock_connect, ceil_threshold=timeout.ceil_threshold - ): - if self._happy_eyeballs_delay is None: - # If happyeyeballs is disabled, connect in sequence - # this avoids a bug in uvloop where it can lose track - # of sockets passed between aiohappyeyeballs.start_connect - # and create_connection and try to reuse the same fd. - # https://github.com/aio-libs/aiohttp/issues/10506 - # https://github.com/MagicStack/uvloop/issues/645 - first_addr_infos = addr_infos[0] - address_tuple = first_addr_infos[4] - host: str = address_tuple[0] - port: int = address_tuple[1] - return await self._loop.create_connection( - *args, host=host, port=port, **kwargs - ) - else: - sock = await aiohappyeyeballs.start_connection( - addr_infos=addr_infos, - local_addr_infos=self._local_addr_infos, - happy_eyeballs_delay=self._happy_eyeballs_delay, - interleave=self._interleave, - loop=self._loop, - socket_factory=self._socket_factory, - ) - return await self._loop.create_connection( - *args, **kwargs, sock=sock - ) - except cert_errors as exc: - raise ClientConnectorCertificateError(req.connection_key, exc) from exc - except ssl_errors as exc: - raise ClientConnectorSSLError(req.connection_key, exc) from exc - except OSError as exc: - if exc.errno is None and isinstance(exc, asyncio.TimeoutError): - raise - raise client_error(req.connection_key, exc) from exc - - async def _wrap_existing_connection( - self, - *args: Any, - req: ClientRequest, + addr_infos: list[AddrInfoType], + req: ClientRequestBase, timeout: "ClientTimeout", - client_error: Type[Exception] = ClientConnectorError, + client_error: type[Exception] = ClientConnectorError, **kwargs: Any, - ) -> Tuple[asyncio.Transport, ResponseHandler]: + ) -> tuple[asyncio.Transport, ResponseHandler]: try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): - return await self._loop.create_connection(*args, **kwargs) + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=self._local_addr_infos, + happy_eyeballs_delay=self._happy_eyeballs_delay, + interleave=self._interleave, + loop=self._loop, + socket_factory=self._socket_factory, + ) + # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used + if ( + kwargs.get("ssl") + and self._ssl_shutdown_timeout + and sys.version_info >= (3, 11) + ): + kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout + return await self._loop.create_connection(*args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1329,56 +1226,13 @@ async def _wrap_existing_connection( raise raise client_error(req.connection_key, exc) from exc - def _fail_on_no_start_tls(self, req: "ClientRequest") -> None: - """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``. - - It is necessary for TLS-in-TLS so that it is possible to - send HTTPS queries through HTTPS proxies. - - This doesn't affect regular HTTP requests, though. - """ - if not req.is_ssl(): - return - - proxy_url = req.proxy - assert proxy_url is not None - if proxy_url.scheme != "https": - return - - self._check_loop_for_start_tls() - - def _check_loop_for_start_tls(self) -> None: - try: - self._loop.start_tls - except AttributeError as attr_exc: - raise RuntimeError( - "An HTTPS request is being sent through an HTTPS proxy. " - "This needs support for TLS in TLS but it is not implemented " - "in your runtime for the stdlib asyncio.\n\n" - "Please upgrade to Python 3.11 or higher. For more details, " - "please see:\n" - "* https://bugs.python.org/issue37179\n" - "* https://github.com/python/cpython/pull/28073\n" - "* https://docs.aiohttp.org/en/stable/" - "client_advanced.html#proxy-support\n" - "* https://github.com/aio-libs/aiohttp/discussions/6044\n", - ) from attr_exc - - def _loop_supports_start_tls(self) -> bool: - try: - self._check_loop_for_start_tls() - except RuntimeError: - return False - else: - return True - def _warn_about_tls_in_tls( self, underlying_transport: asyncio.Transport, req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" - if req.request_info.url.scheme != "https": + if req.url.scheme != "https": return # Check if uvloop is being used, which supports TLS in TLS, @@ -1399,7 +1253,7 @@ def _warn_about_tls_in_tls( warnings.warn( "An HTTPS request is being sent through an HTTPS proxy. " "This support for TLS in TLS is known to be disabled " - "in the stdlib asyncio (Python <3.11). This is why you'll probably see " + "in the stdlib asyncio. This is why you'll probably see " "an error in the log below.\n\n" "It is possible to enable it via monkeypatching. " "For more details, see:\n" @@ -1420,8 +1274,8 @@ async def _start_tls_connection( underlying_transport: asyncio.Transport, req: ClientRequest, timeout: "ClientTimeout", - client_error: Type[Exception] = ClientConnectorError, - ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: + client_error: type[Exception] = ClientConnectorError, + ) -> tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS sslcontext = self._get_ssl_context(req) @@ -1441,7 +1295,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.host, + server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) @@ -1450,7 +1304,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.host, + server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1487,7 +1341,7 @@ async def _start_tls_connection( raise ClientConnectionError( "Cannot initialize a TLS-in-TLS connection to host " - f"{req.host!s}:{req.port:d} through an underlying connection " + f"{req.url.host!s}:{req.url.port:d} through an underlying connection " f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " f"[{type_err!s}]" ) from type_err @@ -1502,14 +1356,14 @@ async def _start_tls_connection( return tls_transport, tls_proto def _convert_hosts_to_addr_infos( - self, hosts: List[ResolveResult] - ) -> List[AddrInfoType]: + self, hosts: list[ResolveResult] + ) -> list[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. """ - addr_infos: List[AddrInfoType] = [] + addr_infos: list[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host @@ -1524,12 +1378,12 @@ def _convert_hosts_to_addr_infos( async def _create_direct_connection( self, - req: ClientRequest, - traces: List["Trace"], + req: ClientRequestBase, + traces: list["Trace"], timeout: "ClientTimeout", *, - client_error: Type[Exception] = ClientConnectorError, - ) -> Tuple[asyncio.Transport, ResponseHandler]: + client_error: type[Exception] = ClientConnectorError, + ) -> tuple[asyncio.Transport, ResponseHandler]: sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) @@ -1540,7 +1394,7 @@ async def _create_direct_connection( # See https://github.com/aio-libs/aiohttp/pull/7364. if host.endswith(".."): host = host.rstrip(".") + "." - port = req.port + port = req.url.port assert port is not None try: # Cancelling this lookup should not cancel the underlying lookup @@ -1554,7 +1408,7 @@ async def _create_direct_connection( # it is problem of resolving proxy ip itself raise ClientConnectorDNSError(req.connection_key, exc) from exc - last_exc: Optional[Exception] = None + last_exc: Exception | None = None addr_infos = self._convert_hosts_to_addr_infos(hosts) while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. @@ -1575,12 +1429,7 @@ async def _create_direct_connection( ) except (ClientConnectorError, asyncio.TimeoutError) as exc: last_exc = exc - if self._happy_eyeballs_delay is None: - addr_infos.pop(0) - else: - aiohappyeyeballs.pop_addr_infos_interleave( - addr_infos, self._interleave - ) + aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave) continue if req.is_ssl() and fingerprint: @@ -1598,24 +1447,18 @@ async def _create_direct_connection( continue return transp, proto - else: - assert last_exc is not None - raise last_exc + assert last_exc is not None + raise last_exc async def _create_proxy_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" - ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: - self._fail_on_no_start_tls(req) - runtime_has_start_tls = self._loop_supports_start_tls() - - headers: Dict[str, str] = {} - if req.proxy_headers is not None: - headers = req.proxy_headers # type: ignore[assignment] + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + ) -> tuple[asyncio.BaseTransport, ResponseHandler]: + headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers headers[hdrs.HOST] = req.headers[hdrs.HOST] url = req.proxy assert url is not None - proxy_req = ClientRequest( + proxy_req = ClientRequestBase( hdrs.METH_GET, url, headers=headers, @@ -1637,8 +1480,7 @@ async def _create_proxy_connection( proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth if req.is_ssl(): - if runtime_has_start_tls: - self._warn_about_tls_in_tls(transport, req) + self._warn_about_tls_in_tls(transport, req) # For HTTPS requests over HTTP proxy # we must notify proxy to tunnel connection @@ -1655,7 +1497,7 @@ async def _create_proxy_connection( proxy=None, proxy_auth=None, proxy_headers_hash=None ) conn = _ConnectTunnelConnection(self, key, proto, self._loop) - proxy_resp = await proxy_req.send(conn) + proxy_resp = await proxy_req._send(conn) try: protocol = conn._protocol assert protocol is not None @@ -1664,7 +1506,7 @@ async def _create_proxy_connection( # once the response is received and processed allowing # START_TLS to work on the connection below. protocol.set_response_params( - read_until_eof=runtime_has_start_tls, + read_until_eof=True, timeout_ceil_threshold=self._timeout_ceil_threshold, ) resp = await proxy_resp.start(conn) @@ -1686,35 +1528,12 @@ async def _create_proxy_connection( message=message, headers=resp.headers, ) - if not runtime_has_start_tls: - rawsock = transport.get_extra_info("socket", default=None) - if rawsock is None: - raise RuntimeError( - "Transport does not expose socket instance" - ) - # Duplicate the socket, so now we can close proxy transport - rawsock = rawsock.dup() except BaseException: # It shouldn't be closed in `finally` because it's fed to # `loop.start_tls()` and the docs say not to touch it after # passing there. transport.close() raise - finally: - if not runtime_has_start_tls: - transport.close() - - if not runtime_has_start_tls: - # HTTP proxy with support for upgrade to HTTPS - sslcontext = self._get_ssl_context(req) - return await self._wrap_existing_connection( - self._factory, - timeout=timeout, - ssl=sslcontext, - sock=rawsock, - server_hostname=req.host, - req=req, - ) return await self._start_tls_connection( # Access the old transport for the last time before it's @@ -1747,17 +1566,15 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: Union[object, float, None] = sentinel, + keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, - loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, - loop=loop, ) self._path = path @@ -1767,7 +1584,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( @@ -1805,17 +1622,15 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: Union[object, float, None] = sentinel, + keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, - loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, - loop=loop, ) if not isinstance( self._loop, @@ -1832,7 +1647,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( From 2e6d996b2d301cdcd9846c08cef34e781d96e96b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:36:42 -0600 Subject: [PATCH 11/42] fallback overloads --- aiohttp/client.py | 8 ++++---- aiohttp/connector.py | 49 ++++++++++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index a4e9ee13de2..153fbf2259b 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -924,7 +924,7 @@ def ws_connect( *, decode_text: bool = ..., **kwargs: Unpack[_WSConnectOptions], - ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": ... + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ... def ws_connect( self, @@ -949,7 +949,7 @@ def ws_connect( compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, decode_text: bool = True, - ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect( @@ -1003,7 +1003,7 @@ async def _ws_connect( *, decode_text: bool = ..., **kwargs: Unpack[_WSConnectOptions], - ) -> "ClientWebSocketResponse[Any]": ... + ) -> "ClientWebSocketResponse[bool]": ... async def _ws_connect( self, @@ -1028,7 +1028,7 @@ async def _ws_connect( compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, decode_text: bool = True, - ) -> "ClientWebSocketResponse[Any]": + ) -> "ClientWebSocketResponse[bool]": if timeout is not sentinel: if isinstance(timeout, ClientWSTimeout): ws_timeout = timeout diff --git a/aiohttp/connector.py b/aiohttp/connector.py index b1820358bae..2cdc425d83f 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -550,6 +550,30 @@ def _available_connections(self, key: "ConnectionKey") -> int: return total_remain + def _update_proxy_auth_header_and_build_proxy_req( + self, req: ClientRequest + ) -> ClientRequestBase: + """Set Proxy-Authorization header for non-SSL proxy requests and builds the proxy request for SSL proxy requests.""" + url = req.proxy + assert url is not None + headers = req.proxy_headers or CIMultiDict[str]() + headers[hdrs.HOST] = req.headers[hdrs.HOST] + proxy_req = ClientRequestBase( + hdrs.METH_GET, + url, + headers=headers, + auth=req.proxy_auth, + loop=self._loop, + ssl=req.ssl, + ) + auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) + if auth is not None: + if not req.is_ssl(): + req.headers[hdrs.PROXY_AUTHORIZATION] = auth + else: + proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth + return proxy_req + async def connect( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> Connection: @@ -558,12 +582,16 @@ async def connect( if (conn := await self._get(key, traces)) is not None: # If we do not have to wait and we can get a connection from the pool # we can avoid the timeout ceil logic and directly return the connection + if req.proxy: + self._update_proxy_auth_header_and_build_proxy_req(req) return conn async with ceil_timeout(timeout.connect, timeout.ceil_threshold): if self._available_connections(key) <= 0: await self._wait_for_available_connection(key, traces) if (conn := await self._get(key, traces)) is not None: + if req.proxy: + self._update_proxy_auth_header_and_build_proxy_req(req) return conn placeholder = cast( @@ -1453,32 +1481,13 @@ async def _create_direct_connection( async def _create_proxy_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> tuple[asyncio.BaseTransport, ResponseHandler]: - headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers - headers[hdrs.HOST] = req.headers[hdrs.HOST] - - url = req.proxy - assert url is not None - proxy_req = ClientRequestBase( - hdrs.METH_GET, - url, - headers=headers, - auth=req.proxy_auth, - loop=self._loop, - ssl=req.ssl, - ) + proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req) # create connection to proxy server transport, proto = await self._create_direct_connection( proxy_req, [], timeout, client_error=ClientProxyConnectionError ) - auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) - if auth is not None: - if not req.is_ssl(): - req.headers[hdrs.PROXY_AUTHORIZATION] = auth - else: - proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth - if req.is_ssl(): self._warn_about_tls_in_tls(transport, req) From 550ca1a5d9d233759f8b70865787b4d0f317b522 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:38:29 -0600 Subject: [PATCH 12/42] fallback overloads --- aiohttp/client_ws.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index a0556140c98..7e8b9d049f6 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -36,7 +36,8 @@ Self = TypeVar("Self", bound="ClientWebSocketResponse[Any]") # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) -_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True]) +# Covariant because it only affects return types, not input types +_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True], covariant=True) @frozen_dataclass_decorator From bc4f6a09ae45db438e4a7a703b82dd9543268745 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:40:34 -0600 Subject: [PATCH 13/42] fallback overloads --- aiohttp/client.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 153fbf2259b..3236202048f 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -246,7 +246,11 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) -_RetType = TypeVar("_RetType", bound="ClientResponse | ClientWebSocketResponse[Any]") +_RetType_co = TypeVar( + "_RetType_co", + bound="ClientResponse | ClientWebSocketResponse[bool]", + covariant=True, +) _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -1467,11 +1471,15 @@ async def __aexit__( await self.close() -class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): +class _BaseRequestContextManager( + Coroutine[Any, Any, _RetType_co], Generic[_RetType_co] +): __slots__ = ("_coro", "_resp") - def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: - self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro + def __init__( + self, coro: Coroutine["asyncio.Future[Any]", None, _RetType_co] + ) -> None: + self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro def send(self, arg: None) -> "asyncio.Future[Any]": return self._coro.send(arg) @@ -1482,15 +1490,15 @@ def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]": def close(self) -> None: return self._coro.close() - def __await__(self) -> Generator[Any, None, _RetType]: + def __await__(self) -> Generator[Any, None, _RetType_co]: ret = self._coro.__await__() return ret - def __iter__(self) -> Generator[Any, None, _RetType]: + def __iter__(self) -> Generator[Any, None, _RetType_co]: return self.__await__() - async def __aenter__(self) -> _RetType: - self._resp: _RetType = await self._coro + async def __aenter__(self) -> _RetType_co: + self._resp: _RetType_co = await self._coro return await self._resp.__aenter__() async def __aexit__( From 56206480f66672ff13b0db8289d22edacb7d437c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:41:19 -0600 Subject: [PATCH 14/42] fallback overloads --- aiohttp/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 3236202048f..ee485fb10a1 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1511,7 +1511,7 @@ async def __aexit__( _RequestContextManager = _BaseRequestContextManager[ClientResponse] -_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse] +_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse[bool]] class _SessionRequestContextManager: From e0255460b6e06aeac5c305f04903e05f25f7e185 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:44:31 -0600 Subject: [PATCH 15/42] fallback overloads --- aiohttp/client.py | 12 +++++------- aiohttp/client_ws.py | 2 +- aiohttp/web_ws.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index ee485fb10a1..b7b5c8a7acb 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1476,15 +1476,13 @@ class _BaseRequestContextManager( ): __slots__ = ("_coro", "_resp") - def __init__( - self, coro: Coroutine["asyncio.Future[Any]", None, _RetType_co] - ) -> None: + def __init__(self, coro: Coroutine[asyncio.Future[Any], None, _RetType_co]) -> None: self._coro: Coroutine[asyncio.Future[Any], None, _RetType_co] = coro - def send(self, arg: None) -> "asyncio.Future[Any]": + def send(self, arg: None) -> asyncio.Future[Any]: return self._coro.send(arg) - def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]": + def throw(self, *args: Any, **kwargs: Any) -> asyncio.Future[Any]: return self._coro.throw(*args, **kwargs) def close(self) -> None: @@ -1499,7 +1497,7 @@ def __iter__(self) -> Generator[Any, None, _RetType_co]: async def __aenter__(self) -> _RetType_co: self._resp: _RetType_co = await self._coro - return await self._resp.__aenter__() + return await self._resp.__aenter__() # type: ignore[return-value] async def __aexit__( self, @@ -1519,7 +1517,7 @@ class _SessionRequestContextManager: def __init__( self, - coro: Coroutine["asyncio.Future[Any]", None, ClientResponse], + coro: Coroutine[asyncio.Future[Any], None, ClientResponse], session: ClientSession, ) -> None: self._coro = coro diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 7e8b9d049f6..feea531e4d9 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -33,7 +33,7 @@ else: import async_timeout - Self = TypeVar("Self", bound="ClientWebSocketResponse[Any]") + Self = TypeVar("Self", bound="ClientWebSocketResponse[bool]") # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) # Covariant because it only affects return types, not input types diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 786149081b7..69ac076f0b4 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -47,7 +47,7 @@ else: import async_timeout - Self = TypeVar("Self", bound="WebSocketResponse[Any]") + Self = TypeVar("Self", bound="WebSocketResponse[bool]") __all__ = ( "WebSocketResponse", From 5ef6d87110485dbb6a8cfc544223df51b1639c2b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:49:21 -0600 Subject: [PATCH 16/42] narrow receive --- aiohttp/_websocket/models.py | 26 ++++++++++++++++++++++++++ aiohttp/client_ws.py | 15 +++++++++++++++ aiohttp/http.py | 4 ++++ aiohttp/web_ws.py | 15 +++++++++++++++ 4 files changed, 60 insertions(+) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index b42f88fd5f3..8819ae80e70 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -142,6 +142,32 @@ class WSMessageError(NamedTuple): WSMessageError, ] +# Message type when decode_text=True (default) - TEXT messages have str data +WSMessageDecodeText = Union[ + WSMessageContinuation, + WSMessageText, + WSMessageBinary, + WSMessagePing, + WSMessagePong, + WSMessageClose, + WSMessageClosing, + WSMessageClosed, + WSMessageError, +] + +# Message type when decode_text=False - TEXT messages have bytes data +WSMessageNoDecodeText = Union[ + WSMessageContinuation, + WSMessageTextBytes, + WSMessageBinary, + WSMessagePing, + WSMessagePong, + WSMessageClose, + WSMessageClosing, + WSMessageClosed, + WSMessageError, +] + WS_CLOSED_MESSAGE = WSMessageClosed() WS_CLOSING_MESSAGE = WSMessageClosing() diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index feea531e4d9..72caa64a23b 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -16,6 +16,8 @@ WebSocketError, WSCloseCode, WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ) from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError @@ -317,6 +319,19 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo self._response.close() return True + @overload + async def receive( + self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None + ) -> WSMessageDecodeText: ... + + @overload + async def receive( + self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None + ) -> WSMessageNoDecodeText: ... + + @overload + async def receive(self, timeout: float | None = None) -> WSMessage: ... + async def receive(self, timeout: float | None = None) -> WSMessage: receive_timeout = timeout or self._timeout.ws_receive diff --git a/aiohttp/http.py b/aiohttp/http.py index 6dad94bb11c..9d50377edf6 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -19,6 +19,8 @@ WebSocketWriter, WSCloseCode, WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ws_ext_gen, ws_ext_parse, @@ -49,6 +51,8 @@ "ws_ext_gen", "ws_ext_parse", "WSMessage", + "WSMessageDecodeText", + "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 69ac076f0b4..640002df223 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -29,6 +29,8 @@ WebSocketWriter, WSCloseCode, WSMessage, + WSMessageDecodeText, + WSMessageNoDecodeText, WSMsgType, ws_ext_gen, ws_ext_parse, @@ -525,6 +527,19 @@ def _close_transport(self) -> None: if self._req is not None and self._req.transport is not None: self._req.transport.close() + @overload + async def receive( + self: "WebSocketResponse[Literal[True]]", timeout: float | None = None + ) -> WSMessageDecodeText: ... + + @overload + async def receive( + self: "WebSocketResponse[Literal[False]]", timeout: float | None = None + ) -> WSMessageNoDecodeText: ... + + @overload + async def receive(self, timeout: float | None = None) -> WSMessage: ... + async def receive(self, timeout: float | None = None) -> WSMessage: if self._reader is None: raise RuntimeError("Call .prepare() first") From 500c23946be4bcfc1cc84d19b08a2dd91c359e8c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:49:43 -0600 Subject: [PATCH 17/42] narrow receive --- aiohttp/http_websocket.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 830318c0b9a..bc6b387c6b3 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -13,7 +13,9 @@ WSMessageClosed, WSMessageClosing, WSMessageContinuation, + WSMessageDecodeText, WSMessageError, + WSMessageNoDecodeText, WSMessagePing, WSMessagePong, WSMessageText, @@ -36,6 +38,8 @@ "WebSocketReader", "WebSocketWriter", "WSMessage", + "WSMessageDecodeText", + "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", From 2b85337fa8f007446b2e6fd131004bf1b2915f1e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:58:13 -0600 Subject: [PATCH 18/42] narrow --- aiohttp/client_ws.py | 19 ++++++++++++++++--- aiohttp/web_ws.py | 19 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 72caa64a23b..f27da03756b 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -330,7 +330,7 @@ async def receive( ) -> WSMessageNoDecodeText: ... @overload - async def receive(self, timeout: float | None = None) -> WSMessage: ... + async def receive(self, timeout: float | None = None) -> WSMessageDecodeText: ... async def receive(self, timeout: float | None = None) -> WSMessage: receive_timeout = timeout or self._timeout.ws_receive @@ -417,7 +417,7 @@ async def receive_str( ) -> bytes: ... @overload - async def receive_str(self, *, timeout: float | None = None) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str: ... async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -459,7 +459,7 @@ async def receive_json( async def receive_json( self, *, - loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + loads: JSONDecoder = ..., timeout: float | None = None, ) -> Any: ... @@ -477,6 +477,19 @@ async def receive_json( def __aiter__(self) -> Self: return self + @overload + async def __anext__( + self: "ClientWebSocketResponse[Literal[True]]", + ) -> WSMessageDecodeText: ... + + @overload + async def __anext__( + self: "ClientWebSocketResponse[Literal[False]]", + ) -> WSMessageNoDecodeText: ... + + @overload + async def __anext__(self) -> WSMessageDecodeText: ... + async def __anext__(self) -> WSMessage: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 640002df223..d6fe16a5c3c 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -538,7 +538,7 @@ async def receive( ) -> WSMessageNoDecodeText: ... @overload - async def receive(self, timeout: float | None = None) -> WSMessage: ... + async def receive(self, timeout: float | None = None) -> WSMessageDecodeText: ... async def receive(self, timeout: float | None = None) -> WSMessage: if self._reader is None: @@ -625,7 +625,7 @@ async def receive_str( ) -> bytes: ... @overload - async def receive_str(self, *, timeout: float | None = None) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str: ... async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -667,7 +667,7 @@ async def receive_json( async def receive_json( self, *, - loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + loads: JSONDecoder = ..., timeout: float | None = None, ) -> Any: ... @@ -690,6 +690,19 @@ async def write( def __aiter__(self) -> Self: return self + @overload + async def __anext__( + self: "WebSocketResponse[Literal[True]]", + ) -> WSMessageDecodeText: ... + + @overload + async def __anext__( + self: "WebSocketResponse[Literal[False]]", + ) -> WSMessageNoDecodeText: ... + + @overload + async def __anext__(self) -> WSMessageDecodeText: ... + async def __anext__(self) -> WSMessage: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): From ef7b0b2f7f7d029cd0bf752db6aab720d01631f1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 13:59:27 -0600 Subject: [PATCH 19/42] default only works on py3.13+ --- aiohttp/client_ws.py | 2 +- aiohttp/web_ws.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index f27da03756b..55ed1eca187 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -39,7 +39,7 @@ # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) # Covariant because it only affects return types, not input types -_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True], covariant=True) +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True) @frozen_dataclass_decorator diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d6fe16a5c3c..b2a9acded72 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -60,7 +60,7 @@ THRESHOLD_CONNLOST_ACCESS: Final[int] = 5 # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) -_DecodeText = TypeVar("_DecodeText", bound=bool, default=Literal[True]) +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True) @frozen_dataclass_decorator From 51b6154781dbd7ea5dd30bd5593dc38b4d0c6109 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:14:16 -0600 Subject: [PATCH 20/42] no fallback; --- aiohttp/client_ws.py | 17 ----------------- aiohttp/web_ws.py | 17 ----------------- 2 files changed, 34 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 55ed1eca187..c963234983e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -329,9 +329,6 @@ async def receive( self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... - @overload - async def receive(self, timeout: float | None = None) -> WSMessageDecodeText: ... - async def receive(self, timeout: float | None = None) -> WSMessage: receive_timeout = timeout or self._timeout.ws_receive @@ -416,9 +413,6 @@ async def receive_str( self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... - @overload - async def receive_str(self, *, timeout: float | None = None) -> str: ... - async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -455,14 +449,6 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... - @overload - async def receive_json( - self, - *, - loads: JSONDecoder = ..., - timeout: float | None = None, - ) -> Any: ... - async def receive_json( self, *, @@ -487,9 +473,6 @@ async def __anext__( self: "ClientWebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... - @overload - async def __anext__(self) -> WSMessageDecodeText: ... - async def __anext__(self) -> WSMessage: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index b2a9acded72..d8600db7d0c 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -537,9 +537,6 @@ async def receive( self: "WebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... - @overload - async def receive(self, timeout: float | None = None) -> WSMessageDecodeText: ... - async def receive(self, timeout: float | None = None) -> WSMessage: if self._reader is None: raise RuntimeError("Call .prepare() first") @@ -624,9 +621,6 @@ async def receive_str( self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... - @overload - async def receive_str(self, *, timeout: float | None = None) -> str: ... - async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -663,14 +657,6 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... - @overload - async def receive_json( - self, - *, - loads: JSONDecoder = ..., - timeout: float | None = None, - ) -> Any: ... - async def receive_json( self, *, @@ -700,9 +686,6 @@ async def __anext__( self: "WebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... - @overload - async def __anext__(self) -> WSMessageDecodeText: ... - async def __anext__(self) -> WSMessage: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): From 53285131d24353ff945995a21fa23801bc16b09f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:16:44 -0600 Subject: [PATCH 21/42] no fallback; --- aiohttp/client_ws.py | 7 ++++--- aiohttp/web_ws.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index c963234983e..adf5e32717d 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -15,7 +15,6 @@ WS_CLOSING_MESSAGE, WebSocketError, WSCloseCode, - WSMessage, WSMessageDecodeText, WSMessageNoDecodeText, WSMsgType, @@ -329,7 +328,9 @@ async def receive( self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... - async def receive(self, timeout: float | None = None) -> WSMessage: + async def receive( + self, timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: receive_timeout = timeout or self._timeout.ws_receive while True: @@ -473,7 +474,7 @@ async def __anext__( self: "ClientWebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... - async def __anext__(self) -> WSMessage: + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): raise StopAsyncIteration diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d8600db7d0c..0764b039652 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -28,7 +28,6 @@ WebSocketReader, WebSocketWriter, WSCloseCode, - WSMessage, WSMessageDecodeText, WSMessageNoDecodeText, WSMsgType, @@ -537,7 +536,9 @@ async def receive( self: "WebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... - async def receive(self, timeout: float | None = None) -> WSMessage: + async def receive( + self, timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: if self._reader is None: raise RuntimeError("Call .prepare() first") @@ -686,7 +687,7 @@ async def __anext__( self: "WebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... - async def __anext__(self) -> WSMessage: + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): raise StopAsyncIteration From 976d816fd3e9e26cf7ddf3a09149400c4f1e1054 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:21:53 -0600 Subject: [PATCH 22/42] try another way for fallback --- aiohttp/client_ws.py | 25 +++++++++++++++++++++++++ aiohttp/web_ws.py | 20 ++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index adf5e32717d..75ea2058173 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -328,6 +328,11 @@ async def receive( self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... + @overload + async def receive( + self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + async def receive( self, timeout: float | None = None ) -> WSMessageDecodeText | WSMessageNoDecodeText: @@ -414,6 +419,11 @@ async def receive_str( self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... + @overload + async def receive_str( + self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None + ) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -450,6 +460,16 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... + @overload + async def receive_json( + self: "ClientWebSocketResponse[_DecodeText]", + *, + loads: ( + JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] + ) = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, @@ -474,6 +494,11 @@ async def __anext__( self: "ClientWebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... + @overload + async def __anext__( + self: "ClientWebSocketResponse[_DecodeText]", + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 0764b039652..4d5feec74c1 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -536,6 +536,11 @@ async def receive( self: "WebSocketResponse[Literal[False]]", timeout: float | None = None ) -> WSMessageNoDecodeText: ... + @overload + async def receive( + self: "WebSocketResponse[_DecodeText]", timeout: float | None = None + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + async def receive( self, timeout: float | None = None ) -> WSMessageDecodeText | WSMessageNoDecodeText: @@ -622,6 +627,11 @@ async def receive_str( self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None ) -> bytes: ... + @overload + async def receive_str( + self: "WebSocketResponse[_DecodeText]", *, timeout: float | None = None + ) -> str | bytes: ... + async def receive_str(self, *, timeout: float | None = None) -> str | bytes: """Receive TEXT message. @@ -658,6 +668,16 @@ async def receive_json( timeout: float | None = None, ) -> Any: ... + @overload + async def receive_json( + self: "WebSocketResponse[_DecodeText]", + *, + loads: ( + JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] + ) = ..., + timeout: float | None = None, + ) -> Any: ... + async def receive_json( self, *, From f6d9d1e89e4f3872c0b3c281df467177b3a2d28e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:22:22 -0600 Subject: [PATCH 23/42] try another way for fallback --- aiohttp/web_ws.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 4d5feec74c1..e7e75217464 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -707,6 +707,11 @@ async def __anext__( self: "WebSocketResponse[Literal[False]]", ) -> WSMessageNoDecodeText: ... + @overload + async def __anext__( + self: "WebSocketResponse[_DecodeText]", + ) -> WSMessageDecodeText | WSMessageNoDecodeText: ... + async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText: msg = await self.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): From 39f2a4790283971ca214e6da0fa3969521d075f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:35:01 -0600 Subject: [PATCH 24/42] need to have default or everything has to be updated --- aiohttp/client_ws.py | 9 +++++++-- aiohttp/web_ws.py | 9 +++++++-- pyproject.toml | 1 + requirements/base.txt | 3 ++- requirements/runtime-deps.in | 1 + requirements/runtime-deps.txt | 5 +++-- requirements/test.txt | 3 ++- 7 files changed, 23 insertions(+), 8 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 75ea2058173..ec22f38a52c 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -4,7 +4,7 @@ import sys from collections.abc import Callable from types import TracebackType -from typing import Any, Final, Generic, Literal, TypeVar, overload +from typing import Any, Final, Generic, Literal, overload from ._websocket.reader import WebSocketDataQueue from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError @@ -28,6 +28,11 @@ JSONEncoder, ) +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 11): import asyncio as async_timeout from typing import Self @@ -38,7 +43,7 @@ # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) # Covariant because it only affects return types, not input types -_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True) +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) @frozen_dataclass_decorator diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index e7e75217464..9b32349d089 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -5,7 +5,7 @@ import json import sys from collections.abc import Callable, Iterable -from typing import Any, Final, Generic, Literal, TypeVar, Union, overload +from typing import Any, Final, Generic, Literal, Union, overload from multidict import CIMultiDict @@ -42,6 +42,11 @@ from .web_request import BaseRequest from .web_response import StreamResponse +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if sys.version_info >= (3, 11): import asyncio as async_timeout from typing import Self @@ -59,7 +64,7 @@ THRESHOLD_CONNLOST_ACCESS: Final[int] = 5 # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) -_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True) +_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) @frozen_dataclass_decorator diff --git a/pyproject.toml b/pyproject.toml index 8b707ddc4cb..f7e6f3168db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "frozenlist >= 1.1.1", "multidict >=4.5, < 7.0", "propcache >= 0.2.0", + "typing_extensions >= 4.0 ; python_version < '3.13'", "yarl >= 1.17.0, < 2.0", ] dynamic = [ diff --git a/requirements/base.txt b/requirements/base.txt index aded022bbca..eb87dd6da6e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -40,8 +40,9 @@ pycares==4.11.0 # via aiodns pycparser==2.23 # via cffi -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # multidict uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 0be3bb7f98f..f1d488679c1 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -10,4 +10,5 @@ brotlicffi; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 +typing_extensions >= 4.0 ; python_version < '3.13' yarl >= 1.17.0, < 2.0 diff --git a/requirements/runtime-deps.txt b/requirements/runtime-deps.txt index f45d006b614..e02c165910f 100644 --- a/requirements/runtime-deps.txt +++ b/requirements/runtime-deps.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/runtime-deps.txt --strip-extras requirements/runtime-deps.in @@ -36,8 +36,9 @@ pycares==4.11.0 # via aiodns pycparser==2.23 # via cffi -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # multidict yarl==1.22.0 diff --git a/requirements/test.txt b/requirements/test.txt index ebe8680fbec..11781407328 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -131,8 +131,9 @@ tomli==2.3.0 # pytest trustme==1.2.1 ; platform_machine != "i686" # via -r requirements/test-common.in -typing-extensions==4.15.0 +typing-extensions==4.15.0 ; python_version < "3.13" # via + # -r requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup From df9d89210bac69d48aae52155390ad5f2a00938e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 14:42:22 -0600 Subject: [PATCH 25/42] update tests as well --- aiohttp/test_utils.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index fc0c4ed1b1e..c333f6a2236 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -287,7 +287,7 @@ def __init__( # type: ignore[misc] self._session._retry_connection = False self._closed = False self._responses: list[ClientResponse] = [] - self._websockets: list[ClientWebSocketResponse] = [] + self._websockets: list[ClientWebSocketResponse[bool]] = [] async def start_server(self) -> None: await self._server.start_server() @@ -440,20 +440,44 @@ def ws_connect( self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... + @overload + def ws_connect( + self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ... + def ws_connect( - self, path: StrOrURL, **kwargs: Any - ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Any]]": + self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any + ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": """Initiate websocket connection. The api corresponds to aiohttp.ClientSession.ws_connect. """ - return _WSRequestContextManager(self._ws_connect(path, **kwargs)) + return _WSRequestContextManager( + self._ws_connect(path, decode_text=decode_text, **kwargs) + ) + + @overload + async def _ws_connect( + self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any + ) -> "ClientWebSocketResponse[Literal[True]]": ... + + @overload + async def _ws_connect( + self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any + ) -> "ClientWebSocketResponse[Literal[False]]": ... + + @overload + async def _ws_connect( + self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any + ) -> "ClientWebSocketResponse[bool]": ... async def _ws_connect( - self, path: StrOrURL, **kwargs: Any - ) -> "ClientWebSocketResponse[Any]": - ws = await self._session.ws_connect(self.make_url(path), **kwargs) + self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any + ) -> "ClientWebSocketResponse[bool]": + ws = await self._session.ws_connect( + self.make_url(path), decode_text=decode_text, **kwargs + ) self._websockets.append(ws) return ws From 4ecc82c5d0e6a4e00aa5557f3fdb07f9ed30ea7c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:07:42 -0600 Subject: [PATCH 26/42] infer from false --- aiohttp/web_ws.py | 40 +++++++++++++++++++++++++- tests/test_client_ws_functional.py | 17 +++++++---- tests/test_web_websocket_functional.py | 10 +++++-- 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9b32349d089..18f838902c0 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -5,7 +5,7 @@ import json import sys from collections.abc import Callable, Iterable -from typing import Any, Final, Generic, Literal, Union, overload +from typing import Any, Final, Generic, Literal, TypedDict, Union, Unpack, overload from multidict import CIMultiDict @@ -67,6 +67,20 @@ _DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) +class _WebSocketResponseParams(TypedDict, total=False): + """Parameters for WebSocketResponse constructor.""" + + timeout: float + receive_timeout: float | None + autoclose: bool + autoping: bool + heartbeat: float | None + protocols: Iterable[str] + compress: bool + max_msg_size: int + writer_limit: int + + @frozen_dataclass_decorator class WebSocketReady: ok: bool @@ -95,6 +109,30 @@ class WebSocketResponse(StreamResponse, Generic[_DecodeText]): _pong_response_cb: asyncio.TimerHandle | None = None _ping_task: asyncio.Task[None] | None = None + @overload + def __new__( + cls, + *, + decode_text: Literal[True] = ..., + **kwargs: Unpack[_WebSocketResponseParams], + ) -> "WebSocketResponse[Literal[True]]": ... + + @overload + def __new__( + cls, + *, + decode_text: Literal[False], + **kwargs: Unpack[_WebSocketResponseParams], + ) -> "WebSocketResponse[Literal[False]]": ... + + def __new__( + cls, + *, + decode_text: bool = True, + **kwargs: Unpack[_WebSocketResponseParams], + ) -> "WebSocketResponse[bool]": + return super().__new__(cls) + def __init__( self, *, diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index a06eb4e14c6..798b3b21682 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import NoReturn +from typing import Literal, NoReturn from unittest import mock import pytest @@ -1277,7 +1277,7 @@ async def handler(request: web.Request) -> NoReturn: app = web.Application() app.router.add_route("GET", "/", handler) - sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse]] = ( + sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse[bool]]] = ( loop.create_future() ) client = await aiohttp_client(app) @@ -1339,8 +1339,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: """Test server receiving TEXT messages as raw bytes with decode_text=False.""" - async def handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse(decode_text=False) + async def handler(request: web.Request) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) await ws.prepare(request) # Receive TEXT message as bytes @@ -1483,11 +1485,14 @@ async def test_receive_json_with_orjson_style_loads( ) -> None: """Test receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads(data: bytes | bytearray | memoryview | str) -> dict: + def orjson_style_loads( + data: bytes | bytearray | memoryview | str, + ) -> dict[str, int]: """Mock orjson.loads that accepts bytes/str.""" if isinstance(data, (bytes, bytearray, memoryview)): data = bytes(data).decode("utf-8") - return json.loads(data) + result: dict[str, int] = json.loads(data) + return result async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 7dbfcdcda8c..5829064400b 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -5,7 +5,7 @@ import json import sys import weakref -from typing import NoReturn +from typing import Literal, NoReturn from unittest import mock import pytest @@ -1485,8 +1485,12 @@ async def test_receive_text_as_bytes_server_iteration( ) -> None: """Test server iterating over WebSocket with decode_text=False.""" - async def websocket_handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse(decode_text=False) + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) await ws.prepare(request) async for msg in ws: From 6bd3b603bbf536164cfe7718ffae39a25d472a3c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:11:11 -0600 Subject: [PATCH 27/42] just set them --- aiohttp/web_ws.py | 40 +------------------------- tests/test_web_websocket_functional.py | 24 ++++++++++++---- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 18f838902c0..9b32349d089 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -5,7 +5,7 @@ import json import sys from collections.abc import Callable, Iterable -from typing import Any, Final, Generic, Literal, TypedDict, Union, Unpack, overload +from typing import Any, Final, Generic, Literal, Union, overload from multidict import CIMultiDict @@ -67,20 +67,6 @@ _DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True]) -class _WebSocketResponseParams(TypedDict, total=False): - """Parameters for WebSocketResponse constructor.""" - - timeout: float - receive_timeout: float | None - autoclose: bool - autoping: bool - heartbeat: float | None - protocols: Iterable[str] - compress: bool - max_msg_size: int - writer_limit: int - - @frozen_dataclass_decorator class WebSocketReady: ok: bool @@ -109,30 +95,6 @@ class WebSocketResponse(StreamResponse, Generic[_DecodeText]): _pong_response_cb: asyncio.TimerHandle | None = None _ping_task: asyncio.Task[None] | None = None - @overload - def __new__( - cls, - *, - decode_text: Literal[True] = ..., - **kwargs: Unpack[_WebSocketResponseParams], - ) -> "WebSocketResponse[Literal[True]]": ... - - @overload - def __new__( - cls, - *, - decode_text: Literal[False], - **kwargs: Unpack[_WebSocketResponseParams], - ) -> "WebSocketResponse[Literal[False]]": ... - - def __new__( - cls, - *, - decode_text: bool = True, - **kwargs: Unpack[_WebSocketResponseParams], - ) -> "WebSocketResponse[bool]": - return super().__new__(cls) - def __init__( self, *, diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 5829064400b..966856bc7f1 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1451,8 +1451,12 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: """Test server receiving TEXT messages as raw bytes with decode_text=False.""" - async def websocket_handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse(decode_text=False) + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) await ws.prepare(request) # Receive TEXT message as bytes @@ -1561,8 +1565,12 @@ async def test_server_receive_str_returns_bytes_with_decode_text_false( ) -> None: """Test that server receive_str() returns bytes when decode_text=False.""" - async def websocket_handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse(decode_text=False) + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) await ws.prepare(request) # receive_str() should return bytes when decode_text=False @@ -1623,8 +1631,12 @@ def orjson_style_loads(data: bytes | bytearray | memoryview | str) -> dict: data = bytes(data).decode("utf-8") return json.loads(data) - async def websocket_handler(request: web.Request) -> web.WebSocketResponse: - ws = web.WebSocketResponse(decode_text=False) + async def websocket_handler( + request: web.Request, + ) -> web.WebSocketResponse[Literal[False]]: + ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( + decode_text=False + ) await ws.prepare(request) # receive_json() with orjson-style loads should work with bytes From cd9a044b6892aa556c42e9a5303d0912801cfd0a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:11:38 -0600 Subject: [PATCH 28/42] just set them --- tests/test_web_websocket_functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 966856bc7f1..6d7bdcd32cd 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1625,11 +1625,14 @@ async def test_server_receive_json_with_orjson_style_loads( ) -> None: """Test server receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads(data: bytes | bytearray | memoryview | str) -> dict: + def orjson_style_loads( + data: bytes | bytearray | memoryview | str, + ) -> dict[str, str]: """Mock orjson.loads that accepts bytes/str.""" if isinstance(data, (bytes, bytearray, memoryview)): data = bytes(data).decode("utf-8") - return json.loads(data) + result: dict[str, str] = json.loads(data) + return result async def websocket_handler( request: web.Request, From c34ba73087e6f19425e4e535524f5aa843b303b0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:13:39 -0600 Subject: [PATCH 29/42] just set them --- aiohttp/client_ws.py | 3 +-- aiohttp/web_ws.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index ec22f38a52c..a269c8fd4d3 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -38,8 +38,7 @@ from typing import Self else: import async_timeout - - Self = TypeVar("Self", bound="ClientWebSocketResponse[bool]") + from typing_extensions import Self # TypeVar for whether text messages are decoded to str (True) or kept as bytes (False) # Covariant because it only affects return types, not input types diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9b32349d089..78c8f1e1fcd 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -52,8 +52,7 @@ from typing import Self else: import async_timeout - - Self = TypeVar("Self", bound="WebSocketResponse[bool]") + from typing_extensions import Self __all__ = ( "WebSocketResponse", From fc32471b935cb1cc9a6ba70be827d3144ce64843 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:28:30 -0600 Subject: [PATCH 30/42] cleanup --- tests/test_client_ws_functional.py | 13 +++++-------- tests/test_web_websocket_functional.py | 13 ++++++------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 798b3b21682..3cefbb26d3d 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1398,8 +1398,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert msg.type is WSMsgType.TEXT assert isinstance(msg.data, bytes) - # Parse JSON directly from bytes (like orjson would) - data = json.loads(msg.data) + # Parse JSON using msg.json() method (covers WSMessageTextBytes.json()) + data = msg.json() assert data == {"response": 84} await resp.close() @@ -1485,12 +1485,9 @@ async def test_receive_json_with_orjson_style_loads( ) -> None: """Test receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads( - data: bytes | bytearray | memoryview | str, - ) -> dict[str, int]: - """Mock orjson.loads that accepts bytes/str.""" - if isinstance(data, (bytes, bytearray, memoryview)): - data = bytes(data).decode("utf-8") + def orjson_style_loads(data: bytes) -> dict[str, int]: + """Mock orjson.loads that accepts bytes.""" + assert isinstance(data, bytes) result: dict[str, int] = json.loads(data) return result diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 6d7bdcd32cd..0e41faa21f2 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1503,7 +1503,9 @@ async def websocket_handler( assert isinstance(msg.data, bytes) # Echo back await ws.send_bytes(msg.data) - elif msg.type is aiohttp.WSMsgType.BINARY: + else: + assert msg.type is aiohttp.WSMsgType.BINARY + assert isinstance(msg.data, bytes) await ws.send_bytes(msg.data) return ws @@ -1625,12 +1627,9 @@ async def test_server_receive_json_with_orjson_style_loads( ) -> None: """Test server receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads( - data: bytes | bytearray | memoryview | str, - ) -> dict[str, str]: - """Mock orjson.loads that accepts bytes/str.""" - if isinstance(data, (bytes, bytearray, memoryview)): - data = bytes(data).decode("utf-8") + def orjson_style_loads(data: bytes) -> dict[str, str]: + """Mock orjson.loads that accepts bytes.""" + assert isinstance(data, bytes) result: dict[str, str] = json.loads(data) return result From 743132dd20b093739447ca6198eb6825d66f33d4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:29:14 -0600 Subject: [PATCH 31/42] changelog --- CHANGES/11763.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/11763.feature.rst diff --git a/CHANGES/11763.feature.rst b/CHANGES/11763.feature.rst new file mode 100644 index 00000000000..07f9a6afdad --- /dev/null +++ b/CHANGES/11763.feature.rst @@ -0,0 +1 @@ +Added ``decode_text`` parameter to :meth:`ClientSession.ws_connect` and :class:`WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`. From 633fc469eeaf7d9636cb9c61173035ccfccd50f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:31:40 -0600 Subject: [PATCH 32/42] reduce --- aiohttp/_websocket/models.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 8819ae80e70..46a5fd8dd7f 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -129,10 +129,9 @@ class WSMessageError(NamedTuple): type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR -WSMessage = Union[ +# Base message types (excluding TEXT variants) +_WSMessageBase = Union[ WSMessageContinuation, - WSMessageText, - WSMessageTextBytes, WSMessageBinary, WSMessagePing, WSMessagePong, @@ -142,31 +141,14 @@ class WSMessageError(NamedTuple): WSMessageError, ] +# All message types +WSMessage = Union[_WSMessageBase, WSMessageText, WSMessageTextBytes] + # Message type when decode_text=True (default) - TEXT messages have str data -WSMessageDecodeText = Union[ - WSMessageContinuation, - WSMessageText, - WSMessageBinary, - WSMessagePing, - WSMessagePong, - WSMessageClose, - WSMessageClosing, - WSMessageClosed, - WSMessageError, -] +WSMessageDecodeText = Union[_WSMessageBase, WSMessageText] # Message type when decode_text=False - TEXT messages have bytes data -WSMessageNoDecodeText = Union[ - WSMessageContinuation, - WSMessageTextBytes, - WSMessageBinary, - WSMessagePing, - WSMessagePong, - WSMessageClose, - WSMessageClosing, - WSMessageClosed, - WSMessageError, -] +WSMessageNoDecodeText = Union[_WSMessageBase, WSMessageTextBytes] WS_CLOSED_MESSAGE = WSMessageClosed() WS_CLOSING_MESSAGE = WSMessageClosing() From a49441f5b92bc4320eba493a27ae252cb412a206 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:33:57 -0600 Subject: [PATCH 33/42] reduce --- tests/test_client_ws_functional.py | 4 +++- tests/test_web_websocket_functional.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 3cefbb26d3d..776666c1154 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1485,7 +1485,9 @@ async def test_receive_json_with_orjson_style_loads( ) -> None: """Test receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads(data: bytes) -> dict[str, int]: + def orjson_style_loads( + data: bytes | bytearray | memoryview | str, + ) -> dict[str, int]: """Mock orjson.loads that accepts bytes.""" assert isinstance(data, bytes) result: dict[str, int] = json.loads(data) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 0e41faa21f2..3cad9f420c8 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1627,7 +1627,9 @@ async def test_server_receive_json_with_orjson_style_loads( ) -> None: """Test server receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads(data: bytes) -> dict[str, str]: + def orjson_style_loads( + data: bytes | bytearray | memoryview | str, + ) -> dict[str, str]: """Mock orjson.loads that accepts bytes.""" assert isinstance(data, bytes) result: dict[str, str] = json.loads(data) From 55852de565a5d1fcbe7e234fc0e6f1a7146e370a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:40:51 -0600 Subject: [PATCH 34/42] fix changelog --- CHANGES/11763.feature.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES/11763.feature.rst b/CHANGES/11763.feature.rst index 07f9a6afdad..b34bfafaca8 100644 --- a/CHANGES/11763.feature.rst +++ b/CHANGES/11763.feature.rst @@ -1 +1 @@ -Added ``decode_text`` parameter to :meth:`ClientSession.ws_connect` and :class:`WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`. +Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`. From b6493eeab0e2dba206faf82020efe5081be12008 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 15:49:09 -0600 Subject: [PATCH 35/42] changelog --- CHANGES/11764.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 120000 CHANGES/11764.feature.rst diff --git a/CHANGES/11764.feature.rst b/CHANGES/11764.feature.rst new file mode 120000 index 00000000000..0860becd808 --- /dev/null +++ b/CHANGES/11764.feature.rst @@ -0,0 +1 @@ +11763.feature.rst \ No newline at end of file From f8015af08a7143d3dc151d8828ad3923cb877282 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 16:16:51 -0600 Subject: [PATCH 36/42] Update pyproject.toml Co-authored-by: Sam Bull --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f7e6f3168db..0cfa7a3221b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "frozenlist >= 1.1.1", "multidict >=4.5, < 7.0", "propcache >= 0.2.0", - "typing_extensions >= 4.0 ; python_version < '3.13'", + "typing_extensions >= 4.4 ; python_version < '3.13'", "yarl >= 1.17.0, < 2.0", ] dynamic = [ From 69c13bf11807a36e7ff23b389bafe207c95d4ba1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 16:19:29 -0600 Subject: [PATCH 37/42] Update aiohttp/_websocket/models.py Co-authored-by: Sam Bull --- aiohttp/_websocket/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 46a5fd8dd7f..d6065f50003 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -68,7 +68,7 @@ class WSMessageTextBytes(NamedTuple): type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT def json( - self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads + self, *, loads: Callable[[bytes], Any] = json.loads ) -> Any: """Return parsed JSON data.""" return loads(self.data) From 79a90368ebaa319227c56fe094371ff647436996 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 22:20:06 +0000 Subject: [PATCH 38/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aiohttp/_websocket/models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index d6065f50003..15159bcacb0 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -67,9 +67,7 @@ class WSMessageTextBytes(NamedTuple): extra: str | None = None type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT - def json( - self, *, loads: Callable[[bytes], Any] = json.loads - ) -> Any: + def json(self, *, loads: Callable[[bytes], Any] = json.loads) -> Any: """Return parsed JSON data.""" return loads(self.data) From ec38532bc747cd1b2a7e40f81ad1ccc22d0b362c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 16:21:19 -0600 Subject: [PATCH 39/42] tweaks --- requirements/runtime-deps.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index f1d488679c1..16515e7551a 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -10,5 +10,5 @@ brotlicffi; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 -typing_extensions >= 4.0 ; python_version < '3.13' +typing_extensions >= 4.4 ; python_version < '3.13' yarl >= 1.17.0, < 2.0 From 080880d173268a2d3f31f6f2d742db064cb5b3ce Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 16:23:14 -0600 Subject: [PATCH 40/42] newer syntax --- aiohttp/_websocket/models.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 15159bcacb0..3d7e6d7d5ac 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -3,7 +3,7 @@ import json from collections.abc import Callable from enum import IntEnum -from typing import Any, Final, Literal, NamedTuple, Union, cast +from typing import Any, Final, Literal, NamedTuple, cast WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) @@ -128,25 +128,25 @@ class WSMessageError(NamedTuple): # Base message types (excluding TEXT variants) -_WSMessageBase = Union[ - WSMessageContinuation, - WSMessageBinary, - WSMessagePing, - WSMessagePong, - WSMessageClose, - WSMessageClosing, - WSMessageClosed, - WSMessageError, -] +_WSMessageBase = ( + WSMessageContinuation + | WSMessageBinary + | WSMessagePing + | WSMessagePong + | WSMessageClose + | WSMessageClosing + | WSMessageClosed + | WSMessageError +) # All message types -WSMessage = Union[_WSMessageBase, WSMessageText, WSMessageTextBytes] +WSMessage = _WSMessageBase | WSMessageText | WSMessageTextBytes # Message type when decode_text=True (default) - TEXT messages have str data -WSMessageDecodeText = Union[_WSMessageBase, WSMessageText] +WSMessageDecodeText = _WSMessageBase | WSMessageText # Message type when decode_text=False - TEXT messages have bytes data -WSMessageNoDecodeText = Union[_WSMessageBase, WSMessageTextBytes] +WSMessageNoDecodeText = _WSMessageBase | WSMessageTextBytes WS_CLOSED_MESSAGE = WSMessageClosed() WS_CLOSING_MESSAGE = WSMessageClosing() From d639aafeed31a8d4f31362777d84fc14d83abbab Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 12 Dec 2025 16:25:52 -0600 Subject: [PATCH 41/42] narrow to bytes --- aiohttp/client_ws.py | 10 +++------- aiohttp/web_ws.py | 10 +++------- tests/test_client_ws_functional.py | 4 +--- tests/test_web_websocket_functional.py | 4 +--- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index a269c8fd4d3..f2e92149e55 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -460,7 +460,7 @@ async def receive_json( async def receive_json( self: "ClientWebSocketResponse[Literal[False]]", *, - loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + loads: Callable[[bytes], Any] = ..., timeout: float | None = None, ) -> Any: ... @@ -468,18 +468,14 @@ async def receive_json( async def receive_json( self: "ClientWebSocketResponse[_DecodeText]", *, - loads: ( - JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] - ) = ..., + loads: JSONDecoder | Callable[[bytes], Any] = ..., timeout: float | None = None, ) -> Any: ... async def receive_json( self, *, - loads: ( - JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] - ) = DEFAULT_JSON_DECODER, + loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER, timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 78c8f1e1fcd..d55d3687d92 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -668,7 +668,7 @@ async def receive_json( async def receive_json( self: "WebSocketResponse[Literal[False]]", *, - loads: Callable[[bytes | bytearray | memoryview | str], Any] = ..., + loads: Callable[[bytes], Any] = ..., timeout: float | None = None, ) -> Any: ... @@ -676,18 +676,14 @@ async def receive_json( async def receive_json( self: "WebSocketResponse[_DecodeText]", *, - loads: ( - JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] - ) = ..., + loads: JSONDecoder | Callable[[bytes], Any] = ..., timeout: float | None = None, ) -> Any: ... async def receive_json( self, *, - loads: ( - JSONDecoder | Callable[[bytes | bytearray | memoryview | str], Any] - ) = json.loads, + loads: JSONDecoder | Callable[[bytes], Any] = json.loads, timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 776666c1154..3cefbb26d3d 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1485,9 +1485,7 @@ async def test_receive_json_with_orjson_style_loads( ) -> None: """Test receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads( - data: bytes | bytearray | memoryview | str, - ) -> dict[str, int]: + def orjson_style_loads(data: bytes) -> dict[str, int]: """Mock orjson.loads that accepts bytes.""" assert isinstance(data, bytes) result: dict[str, int] = json.loads(data) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 3cad9f420c8..0e41faa21f2 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1627,9 +1627,7 @@ async def test_server_receive_json_with_orjson_style_loads( ) -> None: """Test server receive_json() with orjson-style loads that accepts bytes.""" - def orjson_style_loads( - data: bytes | bytearray | memoryview | str, - ) -> dict[str, str]: + def orjson_style_loads(data: bytes) -> dict[str, str]: """Mock orjson.loads that accepts bytes.""" assert isinstance(data, bytes) result: dict[str, str] = json.loads(data) From c85cce7557e5f1941163357736aad1cb4953838b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 14 Dec 2025 10:06:48 -0600 Subject: [PATCH 42/42] docs --- docs/client_reference.rst | 11 ++++++++++- docs/web_reference.rst | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 50d158b5c2a..ab52fbfa5fb 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -722,7 +722,8 @@ The client session supports the context manager protocol for self closing. proxy=None, proxy_auth=None, ssl=True, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ - compress=0, max_msg_size=4194304) + compress=0, max_msg_size=4194304, \ + decode_text=True) :async: Create a websocket connection. Returns a @@ -851,6 +852,14 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.5 + :param bool decode_text: If ``True`` (default), TEXT messages are + decoded to strings. If ``False``, TEXT messages + are returned as raw bytes, which can improve + performance when using JSON parsers like + ``orjson`` that accept bytes directly. + + .. versionadded:: 3.14 + .. method:: close() :async: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 048d798f8c1..01b237f1b0a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -939,7 +939,7 @@ and :ref:`aiohttp-web-signals` handlers:: .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ autoclose=True, autoping=True, heartbeat=None, \ protocols=(), compress=True, max_msg_size=4194304, \ - writer_limit=65536) + writer_limit=65536, decode_text=True) Class for handling server-side websockets, inherited from :class:`StreamResponse`. @@ -1002,6 +1002,14 @@ and :ref:`aiohttp-web-signals` handlers:: .. versionadded:: 3.11 + :param bool decode_text: If ``True`` (default), TEXT messages are + decoded to strings. If ``False``, TEXT messages + are returned as raw bytes, which can improve + performance when using JSON parsers like + ``orjson`` that accept bytes directly. + + .. versionadded:: 3.14 + The class supports ``async for`` statement for iterating over incoming messages::