Skip to content

Commit ddc87bd

Browse files
committed
add boolean option for ASGITransport streaming
1 parent a0b2cc7 commit ddc87bd

2 files changed

Lines changed: 105 additions & 47 deletions

File tree

httpx/_transports/asgi.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import typing
45

56
from .._models import Request, Response
@@ -101,11 +102,9 @@ def __init__(
101102
self,
102103
ignore_body: bool,
103104
asgi_generator: typing.AsyncGenerator[_Message, None],
104-
disconnect_request_event: Event,
105105
) -> None:
106106
self._ignore_body = ignore_body
107107
self._asgi_generator = asgi_generator
108-
self._disconnect_request_event = disconnect_request_event
109108

110109
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
111110
more_body = True
@@ -118,13 +117,10 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
118117
more_body = message.get("more_body", False)
119118
if chunk and not self._ignore_body:
120119
yield chunk
121-
if not more_body:
122-
self._disconnect_request_event.set()
123120
finally:
124121
await self.aclose()
125122

126123
async def aclose(self) -> None:
127-
self._disconnect_request_event.set()
128124
await self._asgi_generator.aclose()
129125

130126

@@ -149,6 +145,9 @@ class ASGITransport(AsyncBaseTransport):
149145
such as testing the content of a client 500 response.
150146
* `root_path` - The root path on which the ASGI application should be mounted.
151147
* `client` - A two-tuple indicating the client IP and port of incoming requests.
148+
* `streaming` - Set to `True` to enable streaming of response content. Default to
149+
`False`, as activating this feature means that the ASGI `app` will run in a
150+
sub-task, which has observable side effects for context variables.
152151
```
153152
"""
154153

@@ -158,18 +157,20 @@ def __init__(
158157
raise_app_exceptions: bool = True,
159158
root_path: str = "",
160159
client: tuple[str, int] = ("127.0.0.1", 123),
160+
*,
161+
streaming: bool = False,
161162
) -> None:
162163
self.app = app
163164
self.raise_app_exceptions = raise_app_exceptions
164165
self.root_path = root_path
165166
self.client = client
167+
self.streaming = streaming
166168

167169
async def handle_async_request(
168170
self,
169171
request: Request,
170172
) -> Response:
171-
disconnect_request_event = create_event()
172-
asgi_generator = self._stream_asgi_messages(request, disconnect_request_event)
173+
asgi_generator = self._stream_asgi_messages(request)
173174

174175
async for message in asgi_generator:
175176
if message["type"] == "http.response.start":
@@ -179,15 +180,13 @@ async def handle_async_request(
179180
stream=ASGIResponseStream(
180181
ignore_body=request.method == "HEAD",
181182
asgi_generator=asgi_generator,
182-
disconnect_request_event=disconnect_request_event,
183183
),
184184
)
185185
else:
186-
disconnect_request_event.set()
187186
return Response(status_code=500, headers=[])
188187

189188
async def _stream_asgi_messages(
190-
self, request: Request, disconnect_request_event: Event
189+
self, request: Request
191190
) -> typing.AsyncGenerator[typing.MutableMapping[str, typing.Any]]:
192191
assert isinstance(request.stream, AsyncByteStream)
193192

@@ -211,9 +210,13 @@ async def _stream_asgi_messages(
211210
request_body_chunks = request.stream.__aiter__()
212211
request_complete = False
213212

213+
# Response.
214+
response_complete = create_event()
215+
214216
# ASGI response messages stream
217+
stream_size = 0 if self.streaming else float("inf")
215218
response_message_send_stream, response_message_recv_stream = (
216-
create_memory_object_stream(0)
219+
create_memory_object_stream(stream_size)
217220
)
218221

219222
# ASGI app exception
@@ -225,7 +228,7 @@ async def receive() -> _Message:
225228
nonlocal request_complete
226229

227230
if request_complete:
228-
await disconnect_request_event.wait()
231+
await response_complete.wait()
229232
return {"type": "http.disconnect"}
230233

231234
try:
@@ -235,17 +238,29 @@ async def receive() -> _Message:
235238
return {"type": "http.request", "body": b"", "more_body": False}
236239
return {"type": "http.request", "body": body, "more_body": True}
237240

241+
async def send(message: _Message) -> None:
242+
await response_message_send_stream.send(message)
243+
if message["type"] == "http.response.body" and not message.get(
244+
"more_body", False
245+
):
246+
response_complete.set()
247+
238248
async def run_app() -> None:
239249
nonlocal app_exception
240250
try:
241-
await self.app(scope, receive, response_message_send_stream.send)
251+
await self.app(scope, receive, send)
242252
except Exception as ex:
243253
app_exception = ex
244254
finally:
245255
await response_message_send_stream.aclose()
246256

247-
async with create_task_group() as task_group:
248-
task_group.start_soon(run_app)
257+
async with contextlib.AsyncExitStack() as exit_stack:
258+
exit_stack.callback(response_complete.set)
259+
if self.streaming:
260+
task_group = await exit_stack.enter_async_context(create_task_group())
261+
task_group.start_soon(run_app)
262+
else:
263+
await run_app()
249264

250265
async with response_message_recv_stream:
251266
try:

tests/test_asgi.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
import json
22

33
import anyio
4+
import contextvars
45
import pytest
56

67
import httpx
78

89

10+
test_asgi_contextvar = contextvars.ContextVar("test_asgi_contextvar")
11+
12+
13+
@pytest.fixture(params=[False, True], ids=["no_streaming", "with_streaming"])
14+
def streaming(request):
15+
return request.param
16+
17+
918
def run_in_task_group(app):
1019
"""A decorator that runs an ASGI callable in a task group"""
1120

@@ -91,8 +100,8 @@ async def raise_exc_after_response(scope, receive, send):
91100

92101

93102
@pytest.mark.anyio
94-
async def test_asgi_transport():
95-
async with httpx.ASGITransport(app=hello_world) as transport:
103+
async def test_asgi_transport(streaming: bool):
104+
async with httpx.ASGITransport(app=hello_world, streaming=streaming) as transport:
96105
request = httpx.Request("GET", "http://www.example.com/")
97106
response = await transport.handle_async_request(request)
98107
await response.aread()
@@ -101,8 +110,8 @@ async def test_asgi_transport():
101110

102111

103112
@pytest.mark.anyio
104-
async def test_asgi_transport_no_body():
105-
async with httpx.ASGITransport(app=echo_body) as transport:
113+
async def test_asgi_transport_no_body(streaming: bool):
114+
async with httpx.ASGITransport(app=echo_body, streaming=streaming) as transport:
106115
request = httpx.Request("GET", "http://www.example.com/")
107116
response = await transport.handle_async_request(request)
108117
await response.aread()
@@ -111,8 +120,8 @@ async def test_asgi_transport_no_body():
111120

112121

113122
@pytest.mark.anyio
114-
async def test_asgi():
115-
transport = httpx.ASGITransport(app=hello_world)
123+
async def test_asgi(streaming: bool):
124+
transport = httpx.ASGITransport(app=hello_world, streaming=streaming)
116125
async with httpx.AsyncClient(transport=transport) as client:
117126
response = await client.get("http://www.example.org/")
118127

@@ -121,8 +130,8 @@ async def test_asgi():
121130

122131

123132
@pytest.mark.anyio
124-
async def test_asgi_urlencoded_path():
125-
transport = httpx.ASGITransport(app=echo_path)
133+
async def test_asgi_urlencoded_path(streaming: bool):
134+
transport = httpx.ASGITransport(app=echo_path, streaming=streaming)
126135
async with httpx.AsyncClient(transport=transport) as client:
127136
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
128137
response = await client.get(url)
@@ -132,8 +141,8 @@ async def test_asgi_urlencoded_path():
132141

133142

134143
@pytest.mark.anyio
135-
async def test_asgi_raw_path():
136-
transport = httpx.ASGITransport(app=echo_raw_path)
144+
async def test_asgi_raw_path(streaming: bool):
145+
transport = httpx.ASGITransport(app=echo_raw_path, streaming=streaming)
137146
async with httpx.AsyncClient(transport=transport) as client:
138147
url = httpx.URL("http://www.example.org/").copy_with(path="/user@example.org")
139148
response = await client.get(url)
@@ -143,11 +152,11 @@ async def test_asgi_raw_path():
143152

144153

145154
@pytest.mark.anyio
146-
async def test_asgi_raw_path_should_not_include_querystring_portion():
155+
async def test_asgi_raw_path_should_not_include_querystring_portion(streaming: bool):
147156
"""
148157
See https://github.com/encode/httpx/issues/2810
149158
"""
150-
transport = httpx.ASGITransport(app=echo_raw_path)
159+
transport = httpx.ASGITransport(app=echo_raw_path, streaming=streaming)
151160
async with httpx.AsyncClient(transport=transport) as client:
152161
url = httpx.URL("http://www.example.org/path?query")
153162
response = await client.get(url)
@@ -157,8 +166,8 @@ async def test_asgi_raw_path_should_not_include_querystring_portion():
157166

158167

159168
@pytest.mark.anyio
160-
async def test_asgi_upload():
161-
transport = httpx.ASGITransport(app=echo_body)
169+
async def test_asgi_upload(streaming: bool):
170+
transport = httpx.ASGITransport(app=echo_body, streaming=streaming)
162171
async with httpx.AsyncClient(transport=transport) as client:
163172
response = await client.post("http://www.example.org/", content=b"example")
164173

@@ -167,8 +176,8 @@ async def test_asgi_upload():
167176

168177

169178
@pytest.mark.anyio
170-
async def test_asgi_headers():
171-
transport = httpx.ASGITransport(app=echo_headers)
179+
async def test_asgi_headers(streaming: bool):
180+
transport = httpx.ASGITransport(app=echo_headers, streaming=streaming)
172181
async with httpx.AsyncClient(transport=transport) as client:
173182
response = await client.get("http://www.example.org/")
174183

@@ -185,31 +194,33 @@ async def test_asgi_headers():
185194

186195

187196
@pytest.mark.anyio
188-
async def test_asgi_exc():
189-
transport = httpx.ASGITransport(app=raise_exc)
197+
async def test_asgi_exc(streaming: bool):
198+
transport = httpx.ASGITransport(app=raise_exc, streaming=streaming)
190199
async with httpx.AsyncClient(transport=transport) as client:
191200
with pytest.raises(RuntimeError):
192201
await client.get("http://www.example.org/")
193202

194203

195204
@pytest.mark.anyio
196-
async def test_asgi_exc_after_response_start():
197-
transport = httpx.ASGITransport(app=raise_exc_after_response_start)
205+
async def test_asgi_exc_after_response_start(streaming: bool):
206+
transport = httpx.ASGITransport(
207+
app=raise_exc_after_response_start, streaming=streaming
208+
)
198209
async with httpx.AsyncClient(transport=transport) as client:
199210
with pytest.raises(RuntimeError):
200211
await client.get("http://www.example.org/")
201212

202213

203214
@pytest.mark.anyio
204-
async def test_asgi_exc_after_response():
205-
transport = httpx.ASGITransport(app=raise_exc_after_response)
215+
async def test_asgi_exc_after_response(streaming: bool):
216+
transport = httpx.ASGITransport(app=raise_exc_after_response, streaming=streaming)
206217
async with httpx.AsyncClient(transport=transport) as client:
207218
with pytest.raises(RuntimeError):
208219
await client.get("http://www.example.org/")
209220

210221

211222
@pytest.mark.anyio
212-
async def test_asgi_disconnect_after_response_complete():
223+
async def test_asgi_disconnect_after_response_complete(streaming: bool):
213224
disconnect = False
214225

215226
async def read_body(scope, receive, send):
@@ -235,7 +246,7 @@ async def read_body(scope, receive, send):
235246
message = await receive()
236247
disconnect = message.get("type") == "http.disconnect"
237248

238-
transport = httpx.ASGITransport(app=read_body)
249+
transport = httpx.ASGITransport(app=read_body, streaming=streaming)
239250
async with httpx.AsyncClient(transport=transport) as client:
240251
response = await client.post("http://www.example.org/", content=b"example")
241252

@@ -244,18 +255,33 @@ async def read_body(scope, receive, send):
244255

245256

246257
@pytest.mark.anyio
247-
async def test_asgi_exc_no_raise():
248-
transport = httpx.ASGITransport(app=raise_exc, raise_app_exceptions=False)
258+
async def test_asgi_exc_no_raise(streaming: bool):
259+
transport = httpx.ASGITransport(
260+
app=raise_exc, raise_app_exceptions=False, streaming=streaming
261+
)
249262
async with httpx.AsyncClient(transport=transport) as client:
250263
response = await client.get("http://www.example.org/")
251264

252265
assert response.status_code == 500
253266

254267

255268
@pytest.mark.anyio
256-
async def test_asgi_exc_no_raise_after_response_start():
269+
async def test_asgi_exc_no_raise_after_response_start(streaming: bool):
270+
transport = httpx.ASGITransport(
271+
app=raise_exc_after_response_start,
272+
raise_app_exceptions=False,
273+
streaming=streaming,
274+
)
275+
async with httpx.AsyncClient(transport=transport) as client:
276+
response = await client.get("http://www.example.org/")
277+
278+
assert response.status_code == 200
279+
280+
281+
@pytest.mark.anyio
282+
async def test_asgi_exc_no_raise_after_response(streaming: bool):
257283
transport = httpx.ASGITransport(
258-
app=raise_exc_after_response_start, raise_app_exceptions=False
284+
app=raise_exc_after_response, raise_app_exceptions=False, streaming=streaming
259285
)
260286
async with httpx.AsyncClient(transport=transport) as client:
261287
response = await client.get("http://www.example.org/")
@@ -264,15 +290,32 @@ async def test_asgi_exc_no_raise_after_response_start():
264290

265291

266292
@pytest.mark.anyio
267-
async def test_asgi_exc_no_raise_after_response():
293+
async def test_asgi_app_runs_in_same_context_as_caller():
294+
async def set_contextvar_in_app(scope, receive, send):
295+
test_asgi_contextvar.set("value_from_app")
296+
297+
status = 200
298+
output = b"Hello, World!"
299+
headers = [
300+
(b"content-type", "text/plain"),
301+
(b"content-length", str(len(output))),
302+
]
303+
304+
await send(
305+
{"type": "http.response.start", "status": status, "headers": headers}
306+
)
307+
await send({"type": "http.response.body", "body": output})
308+
268309
transport = httpx.ASGITransport(
269-
app=raise_exc_after_response, raise_app_exceptions=False
310+
app=set_contextvar_in_app, raise_app_exceptions=False, streaming=False
270311
)
271312
async with httpx.AsyncClient(transport=transport) as client:
272313
response = await client.get("http://www.example.org/")
273314

274315
assert response.status_code == 200
275316

317+
assert test_asgi_contextvar.get(None) == "value_from_app"
318+
276319

277320
@pytest.mark.parametrize(
278321
"send_in_sub_task",
@@ -296,7 +339,7 @@ async def send_response_body_after_event(scope, receive, send):
296339
send_response_body_after_event
297340
)
298341

299-
transport = httpx.ASGITransport(app=send_response_body_after_event)
342+
transport = httpx.ASGITransport(app=send_response_body_after_event, streaming=True)
300343
async with httpx.AsyncClient(transport=transport) as client:
301344
with anyio.fail_after(0.1):
302345
async with client.stream("GET", "http://www.example.org/") as response:
@@ -335,7 +378,7 @@ async def send_response_body_after_event(scope, receive, send):
335378
send_response_body_after_event
336379
)
337380

338-
transport = httpx.ASGITransport(app=send_response_body_after_event)
381+
transport = httpx.ASGITransport(app=send_response_body_after_event, streaming=True)
339382
async with httpx.AsyncClient(transport=transport) as client:
340383
with anyio.fail_after(0.1):
341384
async with client.stream("GET", "http://www.example.org/") as response:

0 commit comments

Comments
 (0)