diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index bc175ccc48..bb2baf1bb5 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -129,6 +129,7 @@ def __init__( self._timeout_config = self._create_timeout_config(timeout) if client is not None: self.client = client + self._non_streaming_client: Client | None = None self._close_http_client = True return if agent_card is None: @@ -144,17 +145,30 @@ def __init__( self._http_client = http_client # Store for cleanup self._close_http_client = True - # Create A2A client using factory - config = ClientConfig( + interceptors = [auth_interceptor] if auth_interceptor is not None else None + + # Create streaming client (SSE transport for stream=True) + streaming_config = ClientConfig( httpx_client=http_client, + streaming=True, supported_protocol_bindings=["JSONRPC"], ) - factory = ClientFactory(config) - interceptors = [auth_interceptor] if auth_interceptor is not None else None + # Create non-streaming client (single request/response for stream=False) + non_streaming_config = ClientConfig( + httpx_client=http_client, + streaming=False, + supported_protocol_bindings=["JSONRPC"], + ) + streaming_factory = ClientFactory(streaming_config) + non_streaming_factory = ClientFactory(non_streaming_config) # Attempt transport negotiation with the provided agent card try: - self.client = factory.create(agent_card, interceptors=interceptors) # type: ignore + self.client = streaming_factory.create(agent_card, interceptors=interceptors) # type: ignore + self._non_streaming_client = non_streaming_factory.create( + agent_card, + interceptors=interceptors, # type: ignore + ) except Exception as transport_error: # Transport negotiation failed - fall back to minimal agent card with JSONRPC fallback_url = agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url @@ -166,7 +180,11 @@ def __init__( ) from transport_error fallback_card = minimal_agent_card(fallback_url, ["JSONRPC"]) try: - self.client = factory.create(fallback_card, interceptors=interceptors) # type: ignore + self.client = streaming_factory.create(fallback_card, interceptors=interceptors) # type: ignore + self._non_streaming_client = non_streaming_factory.create( + fallback_card, + interceptors=interceptors, # type: ignore + ) except Exception as fallback_error: raise RuntimeError( f"A2A transport negotiation failed. " @@ -282,6 +300,13 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] del function_invocation_kwargs, client_kwargs, kwargs normalized_messages = normalize_messages(messages) + # Use non-streaming transport for non-streaming calls when available. + # This sends a single HTTP request/response instead of opening an SSE + # connection, matching the protocol's intent for synchronous operations. + active_client = ( + self._non_streaming_client if (not stream and self._non_streaming_client is not None) else self.client + ) + if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.subscribe( SubscribeToTaskRequest(id=continuation_token["task_id"]) @@ -293,7 +318,11 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] normalized_messages[-1], context_id=session.service_session_id if session else None, ) - a2a_stream = self.client.send_message(SendMessageRequest(message=a2a_message)) + request = SendMessageRequest(message=a2a_message) + if background and not stream: + # return_immediately only applies to non-streaming (message/send) + request.configuration.return_immediately = True + a2a_stream = active_client.send_message(request) provider_session = session if provider_session is None and self.context_providers: diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 76294f30bf..f5474bc374 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -44,6 +44,7 @@ def __init__(self) -> None: self.subscribe_responses: list[StreamResponse] = [] self.get_task_response: Task | None = None self.last_message: Any = None + self.last_request: Any = None def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None: """Add a mock Message response.""" @@ -91,6 +92,7 @@ def add_in_progress_task_response( async def send_message(self, request: Any) -> AsyncIterator[StreamResponse]: """Mock send_message method that yields responses.""" + self.last_request = request self.last_message = getattr(request, "message", request) self.call_count += 1 @@ -745,6 +747,96 @@ async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, moc assert response.continuation_token is None +async def test_background_sets_return_immediately_on_request( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that background=True sets return_immediately=True on SendMessageRequest configuration.""" + mock_a2a_client.add_in_progress_task_response("task-bg", state=TaskState.TASK_STATE_WORKING) + + await a2a_agent.run("Background task", background=True) + + assert mock_a2a_client.last_request.configuration.return_immediately is True + + +async def test_foreground_does_not_set_return_immediately( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that background=False (default) does not set configuration on SendMessageRequest.""" + mock_a2a_client.add_task_response("task-fg2", [{"id": "art-1", "content": "Done"}]) + + await a2a_agent.run("Foreground task") + + assert mock_a2a_client.last_request.HasField("configuration") is False + + +async def test_streaming_background_does_not_set_return_immediately( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that background=True with stream=True does not set return_immediately. + + Per A2A spec, return_immediately only applies to non-streaming (message/send). + """ + mock_a2a_client.add_task_response("task-sb", [{"id": "art-1", "content": "Streaming bg"}]) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Stream background", stream=True, background=True): + updates.append(update) + + assert mock_a2a_client.last_request.HasField("configuration") is False + + +async def test_non_streaming_run_uses_non_streaming_client() -> None: + """Test that stream=False uses the non-streaming client when available.""" + streaming_client = MockA2AClient() + non_streaming_client = MockA2AClient() + non_streaming_client.add_task_response("task-ns", [{"id": "art-1", "content": "Non-streaming result"}]) + + agent = A2AAgent(name="Test Agent", id="test-ns", client=streaming_client, http_client=None) + agent._non_streaming_client = non_streaming_client # type: ignore[assignment] + + response = await agent.run("Hello") + + # Non-streaming client should have been called + assert non_streaming_client.call_count == 1 + assert streaming_client.call_count == 0 + assert response.messages[0].text == "Non-streaming result" + assert non_streaming_client.last_request.HasField("configuration") is False + + +async def test_streaming_run_uses_streaming_client() -> None: + """Test that stream=True always uses the streaming client.""" + streaming_client = MockA2AClient() + non_streaming_client = MockA2AClient() + streaming_client.add_task_response("task-s", [{"id": "art-1", "content": "Streaming result"}]) + + agent = A2AAgent(name="Test Agent", id="test-s", client=streaming_client, http_client=None) + agent._non_streaming_client = non_streaming_client # type: ignore[assignment] + + updates: list[AgentResponseUpdate] = [] + async for update in agent.run("Hello", stream=True): + updates.append(update) + + # Streaming client should have been called + assert streaming_client.call_count == 1 + assert non_streaming_client.call_count == 0 + assert updates[0].contents[0].text == "Streaming result" + + +async def test_non_streaming_client_fallback_when_not_available( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that stream=False falls back to streaming client when non-streaming client is unavailable.""" + mock_a2a_client.add_task_response("task-fb", [{"id": "art-1", "content": "Fallback result"}]) + + # a2a_agent is created with client= param so _non_streaming_client is None + assert a2a_agent._non_streaming_client is None + + response = await a2a_agent.run("Hello") + + assert mock_a2a_client.call_count == 1 + assert response.messages[0].text == "Fallback result" + + async def test_completed_task_has_no_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: """Test that a completed task does not set a continuation token.""" mock_a2a_client.add_task_response("task-done", [{"id": "art-1", "content": "Result"}])