Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion getstream/video/rtc/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from getstream.video.rtc.recording import RecordingManager
from getstream.video.rtc.participants import ParticipantsState
from getstream.video.rtc.tracks import SubscriptionConfig, SubscriptionManager
from getstream.video.rtc.reconnection import ReconnectionManager
from getstream.video.rtc.reconnection import ReconnectionManager, ReconnectionStrategy
from getstream.video.rtc.peer_connection import PeerConnectionManager
from getstream.video.rtc.models import JoinCallResponse
from getstream.video.rtc.tracer import Tracer
Expand Down Expand Up @@ -277,6 +277,25 @@ async def _on_subscriber_offer(self, event: events_pb2.SubscriberOffer):
finally:
self.subscriber_negotiation_lock.release()

async def _on_signaling_connection_lost(self, reason: str) -> None:
    """Bridge an unexpected signaling-WebSocket drop into reconnection.

    The WebSocketClient carries no recovery logic of its own — on a raw
    socket close or a health-check timeout it just logs and stops. By
    routing the loss through the existing ``ReconnectionManager`` here,
    a transient TCP reset no longer strands the session.

    Args:
        reason: Human-readable description of why the socket was lost.
    """
    if self.running:
        logger.warning(f"Signaling WS lost; triggering reconnect: {reason}")
        try:
            await self._reconnector.reconnect(
                strategy=ReconnectionStrategy.FAST,
                reason=f"signaling ws lost: {reason}",
            )
        except Exception:
            # Reconnection is best-effort; never let a failure here
            # propagate into the event-dispatch path.
            logger.exception("Reconnect after signaling WS loss failed")

async def _connect_coordinator_ws(self):
"""
Connects to the coordinator websocket and subscribes to events.
Expand Down Expand Up @@ -414,6 +433,15 @@ async def _connect_internal(
# Connect subscriber offer event to handle SDP negotiation
self._ws_client.on_event("subscriber_offer", self._on_subscriber_offer)

# Drive reconnection when the signaling WS drops outside of an
# SFU-level error event (raw socket close, health-check timeout,
# transport-level exceptions). Without this handler the
# WebSocketClient just logs and stops; the session sits hanging
# until the frontend times out and tears it down.
self._ws_client.on_event(
"connection_lost", self._on_signaling_connection_lost
)

# Re-emit the events so they can be subscribed to on the ConnectionManager
self._ws_client.on_wildcard("*", self.emit)

Expand Down
42 changes: 42 additions & 0 deletions getstream/video/rtc/signaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.thread = None
self.running = False
self.closed = False
self._connection_lost_sent = False

# For ping/health check mechanism
self.ping_thread = None
Expand Down Expand Up @@ -214,11 +215,48 @@ def _on_error(self, ws, error):
error_event.error.error.message = str(error)
self.first_message = error_event
self.first_message_event.set()
elif not self.closed:
self._notify_connection_lost(f"error: {error}")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _on_close(self, ws, close_status_code, close_msg):
    """Handle the WebSocket close callback.

    Marks the client as no longer running and, when the close was not
    user-initiated (``self.closed`` still unset), raises a
    connection-lost notification so the owner can reconnect.
    """
    logger.debug(f"WebSocket connection closed: {close_status_code} {close_msg}")
    self.running = False
    # `self.closed` is only set by an explicit close(); if it is still
    # False the remote side (or the transport) dropped us.
    if not self.closed:
        self._notify_connection_lost(
            f"closed by remote (code={close_status_code} msg={close_msg})"
        )

def _notify_connection_lost(self, reason: str) -> None:
    """Schedule a ``connection_lost`` emit on the main event loop.

    Idempotent per instance — only the first notification per
    disconnect goes through; later calls are no-ops. Callers run on
    the WS worker thread or the ping thread, while pyee schedules
    async listeners with ``loop.create_task``, which is not
    thread-safe. The coroutine is therefore handed over with
    ``run_coroutine_threadsafe`` — the same hop ``_on_message`` uses
    for SFU events.

    Args:
        reason: Human-readable description of the disconnect cause.
    """
    if self._claim_connection_lost():
        try:
            asyncio.run_coroutine_threadsafe(
                self._emit_connection_lost(reason), self.main_loop
            )
        except Exception:
            logger.exception("Failed to schedule connection_lost emit")

def _claim_connection_lost(self) -> bool:
"""Return True iff this is the first connection-lost notification."""
if self._connection_lost_sent:
return False
self._connection_lost_sent = True
return True

async def _emit_connection_lost(self, reason: str) -> None:
    """Fire the ``connection_lost`` event inside the parent trace span."""
    span_context = telemetry.attach_span(self.parent_span)
    with span_context:
        self.emit("connection_lost", reason)

def _start_ping_handler(self):
"""Start the ping mechanism in a background thread."""
Expand All @@ -242,6 +280,10 @@ def _ping_loop(self):
current_time = time.time()
if current_time - self.last_health_check_time > self.ping_interval * 2:
logger.warning("Health check failed, closing connection")
# Notify before close() so the owner can reconnect; close()
# itself sets `self.closed=True` and would suppress the
# notification in `_on_close`.
self._notify_connection_lost("health check timeout")
self.close()
return

Expand Down
21 changes: 21 additions & 0 deletions tests/rtc/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SfuJoinError,
)
from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2
from getstream.video.rtc.reconnection import ReconnectionStrategy

load_dotenv()

Expand Down Expand Up @@ -203,3 +204,23 @@ def test_rejects_negative_max_join_retries(self):
pytest.raises(ValueError, match="max_join_retries must be >= 0"),
):
ConnectionManager(call=MagicMock(), user_id="user1", max_join_retries=-1)

@pytest.mark.asyncio
async def test_signaling_connection_lost_triggers_fast_reconnect(
    self, connection_manager
):
    """A signaling-WS `connection_lost` event must drive a FAST reconnect.

    Absent this handler, a transient socket drop would leave the
    session hanging until the frontend tore it down.
    """
    manager = connection_manager
    manager.running = True
    manager._reconnector.reconnect = AsyncMock()

    await manager._on_signaling_connection_lost("health check timeout")

    manager._reconnector.reconnect.assert_called_once()
    call_kwargs = manager._reconnector.reconnect.call_args.kwargs
    assert call_kwargs["strategy"] == ReconnectionStrategy.FAST
    assert "health check timeout" in call_kwargs["reason"]
106 changes: 106 additions & 0 deletions tests/test_signaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,112 @@ async def test_thread_usage(self, join_request, mock_websocket):
# Thread should be joined during close
assert not client.running

@pytest.mark.asyncio
async def test_connection_lost_emitted_on_unexpected_close(
    self, join_request, mock_websocket
):
    """An unexpected WS close after handshake emits a `connection_lost` event.

    The owner (ConnectionManager) relies on this signal to drive
    reconnection — without it, a transient socket drop leaves the
    session hanging until the frontend times out.
    """
    client = WebSocketClient(
        "wss://test.url", join_request, asyncio.get_running_loop()
    )

    received: list[str] = []

    async def on_lost(reason):
        received.append(reason)

    client.on_event("connection_lost", on_lost)

    # Complete handshake so we're past the initial connect phase.
    join_response = events_pb2.SfuEvent()
    join_response.join_response.reconnected = False

    connect_task = asyncio.create_task(client.connect())
    await asyncio.sleep(0.1)

    # The mocked websocket records the callbacks the client registered;
    # drive them by hand to simulate the server side of the handshake.
    on_open_callback = mock_websocket.call_args[1]["on_open"]
    on_open_callback(mock_websocket.return_value)

    on_message_callback = mock_websocket.call_args[1]["on_message"]
    on_message_callback(
        mock_websocket.return_value, join_response.SerializeToString()
    )
    await connect_task

    # Simulate the remote dropping the connection (not user-initiated).
    # 1006 is the "abnormal closure" WebSocket close code.
    on_close_callback = mock_websocket.call_args[1]["on_close"]
    on_close_callback(mock_websocket.return_value, 1006, "abnormal closure")

    # Allow the threadsafe-scheduled emit to run on the loop.
    await asyncio.sleep(0.1)

    assert len(received) == 1, (
        f"expected exactly one connection_lost event, got {received}"
    )
    assert "1006" in received[0], (
        f"reason should mention the close code, got {received[0]!r}"
    )

    client.close()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@pytest.mark.asyncio
async def test_connection_lost_emitted_once_on_error_then_close(
    self, join_request, mock_websocket
):
    """A chained ``_on_error`` → ``_on_close`` fires exactly one event.

    websocket-client typically delivers an error followed by a close
    for the same drop. Consumers (including the SDK-public re-emit via
    ``ConnectionManager`` wildcard) expect one notification per
    disconnect, not one per callback.
    """
    client = WebSocketClient(
        "wss://test.url", join_request, asyncio.get_running_loop()
    )

    received: list[str] = []

    async def on_lost(reason):
        received.append(reason)

    client.on_event("connection_lost", on_lost)

    # Complete the join handshake so the error/close below counts as an
    # unexpected mid-session drop rather than a connect failure.
    join_response = events_pb2.SfuEvent()
    join_response.join_response.reconnected = False

    connect_task = asyncio.create_task(client.connect())
    await asyncio.sleep(0.1)

    on_open_callback = mock_websocket.call_args[1]["on_open"]
    on_open_callback(mock_websocket.return_value)

    on_message_callback = mock_websocket.call_args[1]["on_message"]
    on_message_callback(
        mock_websocket.return_value, join_response.SerializeToString()
    )
    await connect_task

    # Error first, then close — what websocket-client actually does on
    # most transport-level failures.
    on_error_callback = mock_websocket.call_args[1]["on_error"]
    on_error_callback(mock_websocket.return_value, Exception("boom"))

    on_close_callback = mock_websocket.call_args[1]["on_close"]
    on_close_callback(mock_websocket.return_value, 1006, "abnormal closure")

    # Give the loop a beat so the threadsafe-scheduled emit(s) can run.
    await asyncio.sleep(0.1)

    assert len(received) == 1, (
        f"expected exactly one connection_lost event for the chain, got {received}"
    )

    client.close()

@pytest.mark.asyncio
async def test_on_open_traces_ws_open_and_join_request(
self, join_request, mock_websocket
Expand Down
Loading