diff --git a/python/restate/server.py b/python/restate/server.py index f6a514a..70e3aba 100644 --- a/python/restate/server.py +++ b/python/restate/server.py @@ -12,7 +12,8 @@ import asyncio import logging -from typing import Dict, TypedDict, Literal +import signal +from typing import Dict, Set, TypedDict, Literal from restate.discovery import compute_discovery_json from restate.endpoint import Endpoint @@ -213,7 +214,23 @@ def asgi_app(endpoint: Endpoint) -> RestateAppT: # Prepare request signer identity_verifier = PyIdentityVerifier(endpoint.identity_keys) + active_channels: Set[ReceiveChannel] = set() + sigterm_installed = False + + def _on_sigterm() -> None: + """Notify all active receive channels of graceful shutdown.""" + for ch in active_channels: + ch.notify_shutdown() + async def app(scope: Scope, receive: Receive, send: Send): + nonlocal sigterm_installed + if not sigterm_installed: + loop = asyncio.get_running_loop() + try: + loop.add_signal_handler(signal.SIGTERM, _on_sigterm) + except (NotImplementedError, RuntimeError): + pass # Windows or non-main thread + sigterm_installed = True try: if scope["type"] == "lifespan": raise LifeSpanNotImplemented() @@ -265,11 +282,13 @@ async def app(scope: Scope, receive: Receive, send: Send): # Let us set up restate's execution context for this invocation and handler. # receive_channel = ReceiveChannel(receive) + active_channels.add(receive_channel) try: await process_invocation_to_completion( VMWrapper(request_headers), handler, dict(request_headers), receive_channel, send ) finally: + active_channels.discard(receive_channel) await receive_channel.close() except LifeSpanNotImplemented as e: raise e diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 216d48c..59823d8 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -454,6 +454,8 @@ async def leave(self): # {'type': 'http.request', 'body': b'', 'more_body': True} # {'type': 'http.request', 'body': b'', 'more_body': False} # {'type': 'http.disconnect'} + # Wait for the runtime to explicitly close its side of the input. + # On SIGTERM, the shutdown event unblocks this instead of an arbitrary timeout. await self.receive.block_until_http_input_closed() # finally, we close our side # it is important to do it, after the other side has closed his side, @@ -545,9 +547,9 @@ async def wrapper(f): continue if chunk.get("type") == "http.disconnect": raise DisconnectedException() - if chunk.get("body", None) is not None: - body = chunk.get("body") - assert isinstance(body, bytes) + # Skip empty body frames to avoid hot loop (see #175) + body: bytes | None = chunk.get("body", None) # type: ignore[assignment] + if body is not None and len(body) > 0: self.vm.notify_input(body) if not chunk.get("more_body", False): self.vm.notify_input_closed() diff --git a/python/restate/server_types.py b/python/restate/server_types.py index 63bb9ce..8c9d332 100644 --- a/python/restate/server_types.py +++ b/python/restate/server_types.py @@ -58,6 +58,12 @@ class HTTPRequestEvent(TypedDict): more_body: bool +class HTTPDisconnectEvent(TypedDict): + """ASGI Disconnect event""" + + type: Literal["http.disconnect"] + + class HTTPResponseStartEvent(TypedDict): """ASGI Response start event""" @@ -75,7 +81,7 @@ class HTTPResponseBodyEvent(TypedDict): more_body: bool -ASGIReceiveEvent = HTTPRequestEvent +ASGIReceiveEvent = Union[HTTPRequestEvent, HTTPDisconnectEvent] ASGISendEvent = Union[HTTPResponseStartEvent, HTTPResponseBodyEvent] @@ -158,12 +164,18 @@ async def loop(): async def __call__(self) -> ASGIReceiveEvent | RestateEvent: """Get the next message.""" + if self._disconnected.is_set() and self._queue.empty(): + return {"type": "http.disconnect"} what = await self._queue.get() self._queue.task_done() return what + def notify_shutdown(self) -> None: + """Signal that a graceful shutdown has been requested (e.g. SIGTERM).""" + self._http_input_closed.set() + async def block_until_http_input_closed(self) -> None: - """Wait until the HTTP input is closed""" + """Wait until the HTTP input is closed or a shutdown signal is received.""" await self._http_input_closed.wait() async def enqueue_restate_event(self, what: RestateEvent): diff --git a/tests/disconnect_hotloop.py b/tests/disconnect_hotloop.py new file mode 100644 index 0000000..f81a724 --- /dev/null +++ b/tests/disconnect_hotloop.py @@ -0,0 +1,256 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +Regression tests for disconnect and SIGTERM shutdown handling. + +Covers: +- Hot-loop bug when BidiStream disconnects (empty queue, empty body frames) +- Graceful shutdown via notify_shutdown() unblocking block_until_http_input_closed() +""" + +import asyncio +from typing import cast +from unittest.mock import MagicMock + +import pytest + +from restate.server_types import ASGIReceiveEvent, ReceiveChannel + + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + + +pytestmark = [ + pytest.mark.anyio, +] + + +async def test_receive_channel_returns_disconnect_when_drained(): + """After disconnect, an empty queue should return http.disconnect immediately.""" + events = [ + {"type": "http.request", "body": b"hello", "more_body": True}, + {"type": "http.request", "body": b"", "more_body": False}, + {"type": "http.disconnect"}, + ] + event_iter = iter(events) + + async def mock_receive() -> ASGIReceiveEvent: + try: + return cast(ASGIReceiveEvent, next(event_iter)) + except StopIteration: + # Block forever — simulates the real ASGI receive after disconnect + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + + # Drain all queued events + try: + result1 = await asyncio.wait_for(channel(), timeout=1.0) + assert result1["type"] == "http.request" + + result2 = await asyncio.wait_for(channel(), timeout=1.0) + assert result2["type"] == "http.request" + + result3 = await asyncio.wait_for(channel(), timeout=1.0) + assert result3["type"] == "http.disconnect" + + # Now the queue is drained and _disconnected is set. + # This call should return immediately with a synthetic disconnect, + # NOT block forever. + result4 = await asyncio.wait_for(channel(), timeout=1.0) + assert result4["type"] == "http.disconnect" + finally: + await channel.close() + + +async def test_receive_channel_does_not_block_after_disconnect(): + """Repeated calls after disconnect should all return promptly.""" + events = [ + {"type": "http.disconnect"}, + ] + event_iter = iter(events) + + async def mock_receive() -> ASGIReceiveEvent: + try: + return cast(ASGIReceiveEvent, next(event_iter)) + except StopIteration: + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + + try: + # Consume the real disconnect + result = await asyncio.wait_for(channel(), timeout=1.0) + assert result["type"] == "http.disconnect" + + # Subsequent calls should not block + for _ in range(5): + result = await asyncio.wait_for(channel(), timeout=0.5) + assert result["type"] == "http.disconnect" + finally: + await channel.close() + + +async def test_empty_body_frames_do_not_cause_hotloop(): + """ + When the VM returns DoProgressReadFromInput and the chunk has body=b'', + notify_input should NOT be called (it would cause a tight loop). + The loop should exit via DisconnectedException when http.disconnect arrives. + """ + from restate.server_context import ServerInvocationContext, DisconnectedException + from restate.vm import DoProgressReadFromInput + + # Build a minimal mock context + vm = MagicMock() + vm.take_output.return_value = None + vm.do_progress.return_value = DoProgressReadFromInput() + + handler = MagicMock() + invocation = MagicMock() + send = MagicMock() + + events = [ + {"type": "http.request", "body": b"", "more_body": True}, + {"type": "http.request", "body": b"", "more_body": False}, + {"type": "http.disconnect"}, + ] + event_iter = iter(events) + + async def mock_receive() -> ASGIReceiveEvent: + try: + return cast(ASGIReceiveEvent, next(event_iter)) + except StopIteration: + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + receive_channel = ReceiveChannel(mock_receive) + + ctx = ServerInvocationContext.__new__(ServerInvocationContext) + ctx.vm = vm + ctx.handler = handler + ctx.invocation = invocation + ctx.send = send + ctx.receive = receive_channel + ctx.run_coros_to_execute = {} + ctx.tasks = MagicMock() + + try: + with pytest.raises(DisconnectedException): + await asyncio.wait_for( + ctx.create_poll_or_cancel_coroutine([0]), + timeout=2.0, + ) + + # notify_input should never have been called with empty bytes + for call in vm.notify_input.call_args_list: + arg = call[0][0] + assert len(arg) > 0, f"notify_input called with empty bytes: {arg!r}" + finally: + await receive_channel.close() + + +# ---- Shutdown / SIGTERM tests ---- + + +async def test_block_until_http_input_closed_returns_on_normal_close(): + """block_until_http_input_closed returns when the runtime closes its input.""" + events = [ + {"type": "http.request", "body": b"data", "more_body": True}, + {"type": "http.request", "body": b"", "more_body": False}, + ] + event_iter = iter(events) + + async def mock_receive() -> ASGIReceiveEvent: + try: + return cast(ASGIReceiveEvent, next(event_iter)) + except StopIteration: + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + try: + # Should return promptly once more_body=False is received + await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0) + finally: + await channel.close() + + +async def test_block_until_http_input_closed_returns_on_shutdown(): + """block_until_http_input_closed returns when notify_shutdown() is called, + even if the runtime never closes its input.""" + + async def mock_receive() -> ASGIReceiveEvent: + # Never sends any events — simulates the runtime not closing its side + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + try: + # Schedule shutdown after a short delay + async def trigger_shutdown(): + await asyncio.sleep(0.05) + channel.notify_shutdown() + + asyncio.create_task(trigger_shutdown()) + + # Should return promptly due to shutdown, NOT block forever + await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0) + finally: + await channel.close() + + +async def test_notify_shutdown_is_idempotent(): + """Calling notify_shutdown() multiple times does not raise.""" + + async def mock_receive() -> ASGIReceiveEvent: + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + try: + channel.notify_shutdown() + channel.notify_shutdown() # should not raise + + # Should return immediately since shutdown is already set + await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=0.5) + finally: + await channel.close() + + +async def test_shutdown_unblocks_concurrent_waiters(): + """Multiple concurrent waiters on block_until_http_input_closed + should all be unblocked by a single notify_shutdown().""" + + async def mock_receive() -> ASGIReceiveEvent: + await asyncio.Event().wait() + raise RuntimeError("unreachable") + + channel = ReceiveChannel(mock_receive) + try: + results = [] + + async def waiter(idx: int): + await channel.block_until_http_input_closed() + results.append(idx) + + tasks = [asyncio.create_task(waiter(i)) for i in range(3)] + + await asyncio.sleep(0.05) + channel.notify_shutdown() + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=1.0) + assert sorted(results) == [0, 1, 2] + finally: + await channel.close()