2121)
2222from contextlib import suppress
2323from types import TracebackType
24- from typing import TYPE_CHECKING , Any , Final , Generic , TypedDict , TypeVar , final
24+ from typing import (
25+ TYPE_CHECKING ,
26+ Any ,
27+ Final ,
28+ Generic ,
29+ Literal ,
30+ TypedDict ,
31+ TypeVar ,
32+ final ,
33+ overload ,
34+ )
2535
2636from multidict import CIMultiDict , MultiDict , MultiDictProxy , istr
2737from yarl import URL , Query
@@ -187,6 +197,27 @@ class _RequestOptions(TypedDict, total=False):
187197 middlewares : Sequence [ClientMiddlewareType ] | None
188198
189199
200+ class _WSConnectOptions (TypedDict , total = False ):
201+ method : str
202+ protocols : Collection [str ]
203+ timeout : "ClientWSTimeout | _SENTINEL"
204+ receive_timeout : float | None
205+ autoclose : bool
206+ autoping : bool
207+ heartbeat : float | None
208+ auth : BasicAuth | None
209+ origin : str | None
210+ params : Query
211+ headers : LooseHeaders | None
212+ proxy : StrOrURL | None
213+ proxy_auth : BasicAuth | None
214+ ssl : SSLContext | bool | Fingerprint
215+ server_hostname : str | None
216+ proxy_headers : LooseHeaders | None
217+ compress : int
218+ max_msg_size : int
219+
220+
190221@frozen_dataclass_decorator
191222class ClientTimeout :
192223 total : float | None = None
@@ -215,7 +246,11 @@ class ClientTimeout:
215246# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
216247IDEMPOTENT_METHODS = frozenset ({"GET" , "HEAD" , "OPTIONS" , "TRACE" , "PUT" , "DELETE" })
217248
218- _RetType = TypeVar ("_RetType" , ClientResponse , ClientWebSocketResponse )
249+ _RetType_co = TypeVar (
250+ "_RetType_co" ,
251+ bound = "ClientResponse | ClientWebSocketResponse[bool]" ,
252+ covariant = True ,
253+ )
219254_CharsetResolver = Callable [[ClientResponse , bytes ], str ]
220255
221256
@@ -866,6 +901,35 @@ async def _connect_and_send_request(
866901 )
867902 raise
868903
904+ if sys .version_info >= (3 , 11 ) and TYPE_CHECKING :
905+
906+ @overload
907+ def ws_connect (
908+ self ,
909+ url : StrOrURL ,
910+ * ,
911+ decode_text : Literal [True ] = ...,
912+ ** kwargs : Unpack [_WSConnectOptions ],
913+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]" : ...
914+
915+ @overload
916+ def ws_connect (
917+ self ,
918+ url : StrOrURL ,
919+ * ,
920+ decode_text : Literal [False ],
921+ ** kwargs : Unpack [_WSConnectOptions ],
922+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]" : ...
923+
924+ @overload
925+ def ws_connect (
926+ self ,
927+ url : StrOrURL ,
928+ * ,
929+ decode_text : bool = ...,
930+ ** kwargs : Unpack [_WSConnectOptions ],
931+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]" : ...
932+
869933 def ws_connect (
870934 self ,
871935 url : StrOrURL ,
@@ -888,7 +952,8 @@ def ws_connect(
888952 proxy_headers : LooseHeaders | None = None ,
889953 compress : int = 0 ,
890954 max_msg_size : int = 4 * 1024 * 1024 ,
891- ) -> "_WSRequestContextManager" :
955+ decode_text : bool = True ,
956+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]" :
892957 """Initiate websocket connection."""
893958 return _WSRequestContextManager (
894959 self ._ws_connect (
@@ -911,9 +976,39 @@ def ws_connect(
911976 proxy_headers = proxy_headers ,
912977 compress = compress ,
913978 max_msg_size = max_msg_size ,
979+ decode_text = decode_text ,
914980 )
915981 )
916982
983+ if sys .version_info >= (3 , 11 ) and TYPE_CHECKING :
984+
985+ @overload
986+ async def _ws_connect (
987+ self ,
988+ url : StrOrURL ,
989+ * ,
990+ decode_text : Literal [True ] = ...,
991+ ** kwargs : Unpack [_WSConnectOptions ],
992+ ) -> "ClientWebSocketResponse[Literal[True]]" : ...
993+
994+ @overload
995+ async def _ws_connect (
996+ self ,
997+ url : StrOrURL ,
998+ * ,
999+ decode_text : Literal [False ],
1000+ ** kwargs : Unpack [_WSConnectOptions ],
1001+ ) -> "ClientWebSocketResponse[Literal[False]]" : ...
1002+
1003+ @overload
1004+ async def _ws_connect (
1005+ self ,
1006+ url : StrOrURL ,
1007+ * ,
1008+ decode_text : bool = ...,
1009+ ** kwargs : Unpack [_WSConnectOptions ],
1010+ ) -> "ClientWebSocketResponse[bool]" : ...
1011+
9171012 async def _ws_connect (
9181013 self ,
9191014 url : StrOrURL ,
@@ -936,7 +1031,8 @@ async def _ws_connect(
9361031 proxy_headers : LooseHeaders | None = None ,
9371032 compress : int = 0 ,
9381033 max_msg_size : int = 4 * 1024 * 1024 ,
939- ) -> ClientWebSocketResponse :
1034+ decode_text : bool = True ,
1035+ ) -> "ClientWebSocketResponse[bool]" :
9401036 if timeout is not sentinel :
9411037 if isinstance (timeout , ClientWSTimeout ):
9421038 ws_timeout = timeout
@@ -1098,7 +1194,9 @@ async def _ws_connect(
10981194 transport = conn .transport
10991195 assert transport is not None
11001196 reader = WebSocketDataQueue (conn_proto , 2 ** 16 , loop = self ._loop )
1101- conn_proto .set_parser (WebSocketReader (reader , max_msg_size ), reader )
1197+ conn_proto .set_parser (
1198+ WebSocketReader (reader , max_msg_size , decode_text = decode_text ), reader
1199+ )
11021200 writer = WebSocketWriter (
11031201 conn_proto ,
11041202 transport ,
@@ -1373,31 +1471,33 @@ async def __aexit__(
13731471 await self .close ()
13741472
13751473
1376- class _BaseRequestContextManager (Coroutine [Any , Any , _RetType ], Generic [_RetType ]):
1474+ class _BaseRequestContextManager (
1475+ Coroutine [Any , Any , _RetType_co ], Generic [_RetType_co ]
1476+ ):
13771477 __slots__ = ("_coro" , "_resp" )
13781478
1379- def __init__ (self , coro : Coroutine [" asyncio.Future[Any]" , None , _RetType ]) -> None :
1380- self ._coro : Coroutine [asyncio .Future [Any ], None , _RetType ] = coro
1479+ def __init__ (self , coro : Coroutine [asyncio .Future [Any ], None , _RetType_co ]) -> None :
1480+ self ._coro : Coroutine [asyncio .Future [Any ], None , _RetType_co ] = coro
13811481
1382- def send (self , arg : None ) -> " asyncio.Future[Any]" :
1482+ def send (self , arg : None ) -> asyncio .Future [Any ]:
13831483 return self ._coro .send (arg )
13841484
1385- def throw (self , * args : Any , ** kwargs : Any ) -> " asyncio.Future[Any]" :
1485+ def throw (self , * args : Any , ** kwargs : Any ) -> asyncio .Future [Any ]:
13861486 return self ._coro .throw (* args , ** kwargs )
13871487
13881488 def close (self ) -> None :
13891489 return self ._coro .close ()
13901490
1391- def __await__ (self ) -> Generator [Any , None , _RetType ]:
1491+ def __await__ (self ) -> Generator [Any , None , _RetType_co ]:
13921492 ret = self ._coro .__await__ ()
13931493 return ret
13941494
1395- def __iter__ (self ) -> Generator [Any , None , _RetType ]:
1495+ def __iter__ (self ) -> Generator [Any , None , _RetType_co ]:
13961496 return self .__await__ ()
13971497
1398- async def __aenter__ (self ) -> _RetType :
1399- self ._resp : _RetType = await self ._coro
1400- return await self ._resp .__aenter__ ()
1498+ async def __aenter__ (self ) -> _RetType_co :
1499+ self ._resp : _RetType_co = await self ._coro
1500+ return await self ._resp .__aenter__ () # type: ignore[return-value]
14011501
14021502 async def __aexit__ (
14031503 self ,
@@ -1409,15 +1509,15 @@ async def __aexit__(
14091509
14101510
14111511_RequestContextManager = _BaseRequestContextManager [ClientResponse ]
1412- _WSRequestContextManager = _BaseRequestContextManager [ClientWebSocketResponse ]
1512+ _WSRequestContextManager = _BaseRequestContextManager [ClientWebSocketResponse [ bool ] ]
14131513
14141514
14151515class _SessionRequestContextManager :
14161516 __slots__ = ("_coro" , "_resp" , "_session" )
14171517
14181518 def __init__ (
14191519 self ,
1420- coro : Coroutine [" asyncio.Future[Any]" , None , ClientResponse ],
1520+ coro : Coroutine [asyncio .Future [Any ], None , ClientResponse ],
14211521 session : ClientSession ,
14221522 ) -> None :
14231523 self ._coro = coro
0 commit comments