1111from collections .abc import (
1212 Awaitable ,
1313 Callable ,
14+ Collection ,
1415 Coroutine ,
1516 Generator ,
1617 Iterable ,
1920)
2021from contextlib import suppress
2122from types import TracebackType
22- from typing import TYPE_CHECKING , Any , Final , Generic , TypedDict , TypeVar
23+ from typing import (
24+ TYPE_CHECKING ,
25+ Any ,
26+ Final ,
27+ Generic ,
28+ Literal ,
29+ TypedDict ,
30+ TypeVar ,
31+ overload ,
32+ )
2333
2434import attr
2535from multidict import CIMultiDict , MultiDict , MultiDictProxy , istr
@@ -186,6 +196,30 @@ class _RequestOptions(TypedDict, total=False):
186196 middlewares : Sequence [ClientMiddlewareType ] | None
187197
188198
199+ class _WSConnectOptions (TypedDict , total = False ):
200+ method : str
201+ protocols : Collection [str ]
202+ timeout : "ClientWSTimeout | _SENTINEL"
203+ receive_timeout : float | None
204+ autoclose : bool
205+ autoping : bool
206+ heartbeat : float | None
207+ auth : BasicAuth | None
208+ origin : str | None
209+ params : Query
210+ headers : LooseHeaders | None
211+ proxy : StrOrURL | None
212+ proxy_auth : BasicAuth | None
213+ ssl : SSLContext | bool | Fingerprint
214+ verify_ssl : bool | None
215+ fingerprint : bytes | None
216+ ssl_context : SSLContext | None
217+ server_hostname : str | None
218+ proxy_headers : LooseHeaders | None
219+ compress : int
220+ max_msg_size : int
221+
222+
189223@attr .s (auto_attribs = True , frozen = True , slots = True )
190224class ClientTimeout :
191225 total : float | None = None
@@ -214,7 +248,11 @@ class ClientTimeout:
214248# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
215249IDEMPOTENT_METHODS = frozenset ({"GET" , "HEAD" , "OPTIONS" , "TRACE" , "PUT" , "DELETE" })
216250
217- _RetType = TypeVar ("_RetType" , ClientResponse , ClientWebSocketResponse )
251+ _RetType_co = TypeVar (
252+ "_RetType_co" ,
253+ bound = "ClientResponse | ClientWebSocketResponse[bool]" ,
254+ covariant = True ,
255+ )
218256_CharsetResolver = Callable [[ClientResponse , bytes ], str ]
219257
220258
@@ -917,6 +955,35 @@ async def _connect_and_send_request(
917955 )
918956 raise
919957
958+ if sys .version_info >= (3 , 11 ) and TYPE_CHECKING :
959+
960+ @overload
961+ def ws_connect (
962+ self ,
963+ url : StrOrURL ,
964+ * ,
965+ decode_text : Literal [True ] = ...,
966+ ** kwargs : Unpack [_WSConnectOptions ],
967+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]" : ...
968+
969+ @overload
970+ def ws_connect (
971+ self ,
972+ url : StrOrURL ,
973+ * ,
974+ decode_text : Literal [False ],
975+ ** kwargs : Unpack [_WSConnectOptions ],
976+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]" : ...
977+
978+ @overload
979+ def ws_connect (
980+ self ,
981+ url : StrOrURL ,
982+ * ,
983+ decode_text : bool = ...,
984+ ** kwargs : Unpack [_WSConnectOptions ],
985+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]" : ...
986+
920987 def ws_connect (
921988 self ,
922989 url : StrOrURL ,
@@ -942,7 +1009,8 @@ def ws_connect(
9421009 proxy_headers : LooseHeaders | None = None ,
9431010 compress : int = 0 ,
9441011 max_msg_size : int = 4 * 1024 * 1024 ,
945- ) -> "_WSRequestContextManager" :
1012+ decode_text : bool = True ,
1013+ ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]" :
9461014 """Initiate websocket connection."""
9471015 return _WSRequestContextManager (
9481016 self ._ws_connect (
@@ -968,9 +1036,39 @@ def ws_connect(
9681036 proxy_headers = proxy_headers ,
9691037 compress = compress ,
9701038 max_msg_size = max_msg_size ,
1039+ decode_text = decode_text ,
9711040 )
9721041 )
9731042
1043+ if sys .version_info >= (3 , 11 ) and TYPE_CHECKING :
1044+
1045+ @overload
1046+ async def _ws_connect (
1047+ self ,
1048+ url : StrOrURL ,
1049+ * ,
1050+ decode_text : Literal [True ] = ...,
1051+ ** kwargs : Unpack [_WSConnectOptions ],
1052+ ) -> "ClientWebSocketResponse[Literal[True]]" : ...
1053+
1054+ @overload
1055+ async def _ws_connect (
1056+ self ,
1057+ url : StrOrURL ,
1058+ * ,
1059+ decode_text : Literal [False ],
1060+ ** kwargs : Unpack [_WSConnectOptions ],
1061+ ) -> "ClientWebSocketResponse[Literal[False]]" : ...
1062+
1063+ @overload
1064+ async def _ws_connect (
1065+ self ,
1066+ url : StrOrURL ,
1067+ * ,
1068+ decode_text : bool = ...,
1069+ ** kwargs : Unpack [_WSConnectOptions ],
1070+ ) -> "ClientWebSocketResponse[bool]" : ...
1071+
9741072 async def _ws_connect (
9751073 self ,
9761074 url : StrOrURL ,
@@ -996,7 +1094,8 @@ async def _ws_connect(
9961094 proxy_headers : LooseHeaders | None = None ,
9971095 compress : int = 0 ,
9981096 max_msg_size : int = 4 * 1024 * 1024 ,
999- ) -> ClientWebSocketResponse :
1097+ decode_text : bool = True ,
1098+ ) -> "ClientWebSocketResponse[bool]" :
10001099 if timeout is not sentinel :
10011100 if isinstance (timeout , ClientWSTimeout ):
10021101 ws_timeout = timeout
@@ -1162,7 +1261,9 @@ async def _ws_connect(
11621261 transport = conn .transport
11631262 assert transport is not None
11641263 reader = WebSocketDataQueue (conn_proto , 2 ** 16 , loop = self ._loop )
1165- conn_proto .set_parser (WebSocketReader (reader , max_msg_size ), reader )
1264+ conn_proto .set_parser (
1265+ WebSocketReader (reader , max_msg_size , decode_text = decode_text ), reader
1266+ )
11661267 writer = WebSocketWriter (
11671268 conn_proto ,
11681269 transport ,
@@ -1467,32 +1568,34 @@ async def __aexit__(
14671568 await self .close ()
14681569
14691570
1470- class _BaseRequestContextManager (Coroutine [Any , Any , _RetType ], Generic [_RetType ]):
1571+ class _BaseRequestContextManager (
1572+ Coroutine [Any , Any , _RetType_co ], Generic [_RetType_co ]
1573+ ):
14711574
14721575 __slots__ = ("_coro" , "_resp" )
14731576
1474- def __init__ (self , coro : Coroutine [" asyncio.Future[Any]" , None , _RetType ]) -> None :
1475- self ._coro : Coroutine [asyncio .Future [Any ], None , _RetType ] = coro
1577+ def __init__ (self , coro : Coroutine [asyncio .Future [Any ], None , _RetType_co ]) -> None :
1578+ self ._coro : Coroutine [asyncio .Future [Any ], None , _RetType_co ] = coro
14761579
1477- def send (self , arg : None ) -> " asyncio.Future[Any]" :
1580+ def send (self , arg : None ) -> asyncio .Future [Any ]:
14781581 return self ._coro .send (arg )
14791582
1480- def throw (self , * args : Any , ** kwargs : Any ) -> " asyncio.Future[Any]" :
1583+ def throw (self , * args : Any , ** kwargs : Any ) -> asyncio .Future [Any ]:
14811584 return self ._coro .throw (* args , ** kwargs )
14821585
14831586 def close (self ) -> None :
14841587 return self ._coro .close ()
14851588
1486- def __await__ (self ) -> Generator [Any , None , _RetType ]:
1589+ def __await__ (self ) -> Generator [Any , None , _RetType_co ]:
14871590 ret = self ._coro .__await__ ()
14881591 return ret
14891592
1490- def __iter__ (self ) -> Generator [Any , None , _RetType ]:
1593+ def __iter__ (self ) -> Generator [Any , None , _RetType_co ]:
14911594 return self .__await__ ()
14921595
1493- async def __aenter__ (self ) -> _RetType :
1494- self ._resp : _RetType = await self ._coro
1495- return await self ._resp .__aenter__ ()
1596+ async def __aenter__ (self ) -> _RetType_co :
1597+ self ._resp : _RetType_co = await self ._coro
1598+ return await self ._resp .__aenter__ () # type: ignore[return-value]
14961599
14971600 async def __aexit__ (
14981601 self ,
@@ -1504,7 +1607,7 @@ async def __aexit__(
15041607
15051608
15061609_RequestContextManager = _BaseRequestContextManager [ClientResponse ]
1507- _WSRequestContextManager = _BaseRequestContextManager [ClientWebSocketResponse ]
1610+ _WSRequestContextManager = _BaseRequestContextManager [ClientWebSocketResponse [ bool ] ]
15081611
15091612
15101613class _SessionRequestContextManager :
@@ -1513,7 +1616,7 @@ class _SessionRequestContextManager:
15131616
15141617 def __init__ (
15151618 self ,
1516- coro : Coroutine [" asyncio.Future[Any]" , None , ClientResponse ],
1619+ coro : Coroutine [asyncio .Future [Any ], None , ClientResponse ],
15171620 session : ClientSession ,
15181621 ) -> None :
15191622 self ._coro = coro
0 commit comments