|
14 | 14 | from contextlib import asynccontextmanager |
15 | 15 | from dataclasses import dataclass, field |
16 | 16 | from typing import Any |
17 | | -from unittest.mock import MagicMock |
| 17 | +from unittest.mock import MagicMock, patch |
18 | 18 | from urllib.parse import urlparse |
19 | 19 |
|
20 | 20 | import anyio |
|
29 | 29 |
|
30 | 30 | from mcp import MCPError, types |
31 | 31 | from mcp.client.session import ClientSession |
32 | | -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
| 32 | +from mcp.client.streamable_http import ( |
| 33 | + MAX_RECONNECTION_ATTEMPTS, |
| 34 | + RequestContext as ClientRequestContext, |
| 35 | + StreamableHTTPTransport, |
| 36 | + streamable_http_client, |
| 37 | +) |
33 | 38 | from mcp.server import Server, ServerRequestContext |
34 | 39 | from mcp.server.streamable_http import ( |
35 | 40 | MCP_PROTOCOL_VERSION_HEADER, |
@@ -2318,3 +2323,73 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( |
2318 | 2323 |
|
2319 | 2324 | assert "content-type" in headers_data |
2320 | 2325 | assert headers_data["content-type"] == "application/json" |
| 2326 | + |
| 2327 | + |
| 2328 | +@pytest.mark.anyio |
| 2329 | +async def test_handle_reconnection_does_not_retry_infinitely(): |
| 2330 | + """Reconnection must count TOTAL attempts, not reset on each successful connect. |
| 2331 | +
|
| 2332 | + Regression test for #2393: when a stream connects successfully but drops |
| 2333 | + before delivering a response, the attempt counter was reset to 0 on the |
| 2334 | + recursive call, allowing an infinite retry loop. |
| 2335 | +
|
| 2336 | + This test simulates a stream that connects, yields one non-completing SSE |
| 2337 | + event, then ends — repeatedly. With MAX_RECONNECTION_ATTEMPTS the loop |
| 2338 | + must terminate. |
| 2339 | + """ |
| 2340 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2341 | + transport.session_id = "test-session" |
| 2342 | + |
| 2343 | + # Track how many times aconnect_sse is called |
| 2344 | + connect_count = 0 |
| 2345 | + |
| 2346 | + @asynccontextmanager |
| 2347 | + async def fake_aconnect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[Any]: |
| 2348 | + """Simulate a stream that connects OK, yields one event, then ends.""" |
| 2349 | + nonlocal connect_count |
| 2350 | + connect_count += 1 |
| 2351 | + |
| 2352 | + mock_response = MagicMock() |
| 2353 | + mock_response.raise_for_status = MagicMock() |
| 2354 | + |
| 2355 | + # Yield a single non-completing notification SSE event, then end the stream |
| 2356 | + # (simulating a server that drops the connection after partial delivery) |
| 2357 | + async def aiter_sse() -> AsyncIterator[ServerSentEvent]: |
| 2358 | + yield ServerSentEvent( |
| 2359 | + event="message", |
| 2360 | + data='{"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"tok","progress":1,"total":10}}', |
| 2361 | + id=f"evt-{connect_count}", |
| 2362 | + retry=None, |
| 2363 | + ) |
| 2364 | + |
| 2365 | + event_source = MagicMock() |
| 2366 | + event_source.response = mock_response |
| 2367 | + event_source.aiter_sse = aiter_sse |
| 2368 | + yield event_source |
| 2369 | + |
| 2370 | + # Build a minimal RequestContext for _handle_reconnection |
| 2371 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](32) |
| 2372 | + |
| 2373 | + async with write_stream, read_stream: |
| 2374 | + request_message = JSONRPCRequest(jsonrpc="2.0", id="req-1", method="tools/call", params={}) |
| 2375 | + session_message = SessionMessage(request_message) |
| 2376 | + ctx = ClientRequestContext( |
| 2377 | + client=MagicMock(), |
| 2378 | + session_id="test-session", |
| 2379 | + session_message=session_message, |
| 2380 | + metadata=None, |
| 2381 | + read_stream_writer=write_stream, |
| 2382 | + ) |
| 2383 | + |
| 2384 | + with patch("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse): |
| 2385 | + # Use a short sleep override so the test doesn't wait on reconnection delays |
| 2386 | + with patch("mcp.client.streamable_http.DEFAULT_RECONNECTION_DELAY_MS", 0): |
| 2387 | + await transport._handle_reconnection(ctx, last_event_id="evt-0", retry_interval_ms=0) |
| 2388 | + |
| 2389 | + # The method should have connected at most MAX_RECONNECTION_ATTEMPTS times |
| 2390 | + # (one for the initial call at attempt=0, then up to MAX-1 more) |
| 2391 | + assert connect_count <= MAX_RECONNECTION_ATTEMPTS, ( |
| 2392 | + f"Expected at most {MAX_RECONNECTION_ATTEMPTS} reconnection attempts, " |
| 2393 | + f"but aconnect_sse was called {connect_count} times — " |
| 2394 | + f"the attempt counter is not being incremented across reconnections" |
| 2395 | + ) |
0 commit comments