From 0ef1ebc67e0181feec6698c6be675e22897210a3 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 28 Oct 2025 09:23:57 -0700 Subject: [PATCH] fix: GenAI Client(evals) - Apply async function for agent run PiperOrigin-RevId: 825073570 --- tests/unit/vertexai/genai/test_evals.py | 32 ++++++++++++++++++------- vertexai/_genai/_evals_common.py | 28 +++++++++++----------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index b6da92e69c..d4886f2978 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -1070,7 +1070,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( ) mock_agent_engine = mock.Mock() - mock_agent_engine.create_session.return_value = {"id": "session1"} + mock_agent_engine.async_create_session = mock.AsyncMock( + return_value={"id": "session1"} + ) stream_query_return_value = [ { "id": "1", @@ -1086,7 +1088,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( }, ] - mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) + async def _async_iterator(iterable): + for item in iterable: + yield item + + mock_agent_engine.async_stream_query.return_value = _async_iterator( + stream_query_return_value + ) mock_vertexai_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -1100,10 +1108,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) - mock_agent_engine.create_session.assert_called_once_with( + mock_agent_engine.async_create_session.assert_called_once_with( user_id="123", state={"a": "1"} ) - mock_agent_engine.stream_query.assert_called_once_with( + mock_agent_engine.async_stream_query.assert_called_once_with( user_id="123", session_id="session1", message="agent prompt" ) @@ -1154,7 +1162,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( ) mock_agent_engine = mock.Mock() - mock_agent_engine.create_session.return_value = {"id": "session1"} + mock_agent_engine.async_create_session = mock.AsyncMock( + return_value={"id": "session1"} + ) stream_query_return_value = [ { "id": "1", @@ -1170,7 +1180,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( }, ] - mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) + async def _async_iterator(iterable): + for item in iterable: + yield item + + mock_agent_engine.async_stream_query.return_value = _async_iterator( + stream_query_return_value + ) mock_vertexai_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -1184,10 +1200,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) - mock_agent_engine.create_session.assert_called_once_with( + mock_agent_engine.async_create_session.assert_called_once_with( user_id="123", state={"a": "1"} ) - mock_agent_engine.stream_query.assert_called_once_with( + mock_agent_engine.async_stream_query.assert_called_once_with( user_id="123", session_id="session1", message="agent prompt" ) diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index bd43229bd9..1e9469be67 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -41,6 +41,7 @@ from . import evals from . import types +from . import agent_engines try: import litellm @@ -62,12 +63,9 @@ def _get_agent_engine_instance( if not hasattr(_thread_local_data, "agent_engine_instances"): _thread_local_data.agent_engine_instances = {} if agent_name not in _thread_local_data.agent_engine_instances: - client = vertexai.Client( - project=api_client.project, - location=api_client.location, - ) + agent_engines_module = agent_engines.AgentEngines(api_client_=api_client) _thread_local_data.agent_engine_instances[agent_name] = ( - client.agent_engines.get(name=agent_name) + agent_engines_module.get(name=agent_name) ) return _thread_local_data.agent_engine_instances[agent_name] @@ -278,10 +276,12 @@ def agent_run_wrapper( and type(agent_engine).__name__ == "AgentEngine" ): agent_engine_instance = agent_engine - return inference_fn_arg( - row=row_arg, - contents=contents_arg, - agent_engine=agent_engine_instance, + return asyncio.run( + inference_fn_arg( + row=row_arg, + contents=contents_arg, + agent_engine=agent_engine_instance, + ) ) future = executor.submit( @@ -1262,7 +1262,7 @@ def _run_agent( ) -def _execute_agent_run_with_retry( +async def _execute_agent_run_with_retry( row: pd.Series, contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent_engine: types.AgentEngine, @@ -1284,7 +1284,7 @@ def _execute_agent_run_with_retry( ) user_id = session_inputs.user_id session_state = session_inputs.state - session = agent_engine.create_session( + session = await agent_engine.async_create_session( user_id=user_id, state=session_state, ) @@ -1295,7 +1295,7 @@ def _execute_agent_run_with_retry( for attempt in range(max_retries): try: responses = [] - for event in agent_engine.stream_query( + async for event in agent_engine.async_stream_query( user_id=user_id, session_id=session["id"], message=contents, @@ -1314,7 +1314,7 @@ def _execute_agent_run_with_retry( ) if attempt == max_retries - 1: return {"error": f"Resource exhausted after retries: {e}"} - time.sleep(2**attempt) + await asyncio.sleep(2**attempt) except Exception as e: # pylint: disable=broad-exception-caught logger.error( "Unexpected error during generate_content on attempt %d/%d: %s", @@ -1325,7 +1325,7 @@ def _execute_agent_run_with_retry( if attempt == max_retries - 1: return {"error": f"Failed after retries: {e}"} - time.sleep(1) + await asyncio.sleep(1) return {"error": f"Failed to get agent run results after {max_retries} retries"}