32 changes: 24 additions & 8 deletions tests/unit/vertexai/genai/test_evals.py
@@ -1070,7 +1070,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
         )

         mock_agent_engine = mock.Mock()
-        mock_agent_engine.create_session.return_value = {"id": "session1"}
+        mock_agent_engine.async_create_session = mock.AsyncMock(
+            return_value={"id": "session1"}
+        )
         stream_query_return_value = [
             {
                 "id": "1",
@@ -1086,7 +1088,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
             },
         ]

-        mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
+        async def _async_iterator(iterable):
+            for item in iterable:
+                yield item
+
+        mock_agent_engine.async_stream_query.return_value = _async_iterator(
+            stream_query_return_value
+        )
         mock_vertexai_client.return_value.agent_engines.get.return_value = (
             mock_agent_engine
         )
@@ -1100,10 +1108,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
         mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
             name="projects/test-project/locations/us-central1/reasoningEngines/123"
         )
-        mock_agent_engine.create_session.assert_called_once_with(
+        mock_agent_engine.async_create_session.assert_called_once_with(
             user_id="123", state={"a": "1"}
         )
-        mock_agent_engine.stream_query.assert_called_once_with(
+        mock_agent_engine.async_stream_query.assert_called_once_with(
             user_id="123", session_id="session1", message="agent prompt"
         )

@@ -1154,7 +1162,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
         )

         mock_agent_engine = mock.Mock()
-        mock_agent_engine.create_session.return_value = {"id": "session1"}
+        mock_agent_engine.async_create_session = mock.AsyncMock(
+            return_value={"id": "session1"}
+        )
         stream_query_return_value = [
             {
                 "id": "1",
@@ -1170,7 +1180,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
             },
         ]

-        mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
+        async def _async_iterator(iterable):
+            for item in iterable:
+                yield item
+
+        mock_agent_engine.async_stream_query.return_value = _async_iterator(
+            stream_query_return_value
+        )
         mock_vertexai_client.return_value.agent_engines.get.return_value = (
             mock_agent_engine
         )
@@ -1184,10 +1200,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
         mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
             name="projects/test-project/locations/us-central1/reasoningEngines/123"
         )
-        mock_agent_engine.create_session.assert_called_once_with(
+        mock_agent_engine.async_create_session.assert_called_once_with(
             user_id="123", state={"a": "1"}
         )
-        mock_agent_engine.stream_query.assert_called_once_with(
+        mock_agent_engine.async_stream_query.assert_called_once_with(
             user_id="123", session_id="session1", message="agent prompt"
        )

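For reference, both tests now mock the async surface: an AsyncMock for session creation plus a small async-generator helper standing in for the streaming query. Below is a minimal, self-contained sketch of that mocking pattern; the engine object and the _demo function are illustrative and not part of the change.

import asyncio
from unittest import mock


async def _async_iterator(iterable):
    # Wrap a plain iterable so it can be consumed with `async for`.
    for item in iterable:
        yield item


async def _demo():
    engine = mock.Mock()
    engine.async_create_session = mock.AsyncMock(return_value={"id": "session1"})
    engine.async_stream_query.return_value = _async_iterator(
        [{"id": "1"}, {"id": "2"}]
    )

    session = await engine.async_create_session(user_id="123", state={"a": "1"})
    events = [
        event
        async for event in engine.async_stream_query(
            user_id="123", session_id=session["id"], message="agent prompt"
        )
    ]
    assert [event["id"] for event in events] == ["1", "2"]


asyncio.run(_demo())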
28 changes: 14 additions & 14 deletions vertexai/_genai/_evals_common.py
@@ -41,6 +41,7 @@

 from . import evals
 from . import types
+from . import agent_engines

 try:
     import litellm
@@ -62,12 +63,9 @@ def _get_agent_engine_instance(
     if not hasattr(_thread_local_data, "agent_engine_instances"):
         _thread_local_data.agent_engine_instances = {}
     if agent_name not in _thread_local_data.agent_engine_instances:
-        client = vertexai.Client(
-            project=api_client.project,
-            location=api_client.location,
-        )
+        agent_engines_module = agent_engines.AgentEngines(api_client_=api_client)
         _thread_local_data.agent_engine_instances[agent_name] = (
-            client.agent_engines.get(name=agent_name)
+            agent_engines_module.get(name=agent_name)
         )
     return _thread_local_data.agent_engine_instances[agent_name]

@@ -278,10 +276,12 @@ def agent_run_wrapper(
                 and type(agent_engine).__name__ == "AgentEngine"
             ):
                 agent_engine_instance = agent_engine
-            return inference_fn_arg(
-                row=row_arg,
-                contents=contents_arg,
-                agent_engine=agent_engine_instance,
+            return asyncio.run(
+                inference_fn_arg(
+                    row=row_arg,
+                    contents=contents_arg,
+                    agent_engine=agent_engine_instance,
+                )
             )

         future = executor.submit(
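The wrapper keeps the existing thread-pool execution model and simply drives the now-async inference function to completion with asyncio.run inside each worker. A rough standalone sketch of that pattern, with illustrative function and variable names:

import asyncio
from concurrent.futures import ThreadPoolExecutor


async def run_one(item: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for async agent I/O
    return f"response for {item}"


def worker(item: str) -> str:
    # asyncio.run creates and tears down a private event loop per call, which is
    # safe here because each worker thread has no running loop of its own.
    return asyncio.run(run_one(item))


with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(worker, item) for item in ["a", "b", "c"]]
    results = [f.result() for f in futures]
print(results)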
@@ -1262,7 +1262,7 @@ def _run_agent(
     )


-def _execute_agent_run_with_retry(
+async def _execute_agent_run_with_retry(
     row: pd.Series,
     contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict],
     agent_engine: types.AgentEngine,
@@ -1284,7 +1284,7 @@ def _execute_agent_run_with_retry(
     )
     user_id = session_inputs.user_id
     session_state = session_inputs.state
-    session = agent_engine.create_session(
+    session = await agent_engine.async_create_session(
         user_id=user_id,
         state=session_state,
     )
@@ -1295,7 +1295,7 @@ def _execute_agent_run_with_retry(
     for attempt in range(max_retries):
         try:
             responses = []
-            async for event in agent_engine.stream_query(
+            async for event in agent_engine.async_stream_query(
                 user_id=user_id,
                 session_id=session["id"],
                 message=contents,
@@ -1314,7 +1314,7 @@
            )
            if attempt == max_retries - 1:
                return {"error": f"Resource exhausted after retries: {e}"}
-            time.sleep(2**attempt)
+            await asyncio.sleep(2**attempt)
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Unexpected error during generate_content on attempt %d/%d: %s",
@@ -1325,7 +1325,7 @@

            if attempt == max_retries - 1:
                return {"error": f"Failed after retries: {e}"}
-            time.sleep(1)
+            await asyncio.sleep(1)
     return {"error": f"Failed to get agent run results after {max_retries} retries"}

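Taken together, the retry loop becomes a coroutine: it awaits session creation, consumes the stream with async for, and backs off with asyncio.sleep instead of time.sleep so the event loop is never blocked. A condensed sketch of that shape; the agent_engine argument and the broad exception handler below are simplified placeholders rather than the module's actual error handling.

import asyncio


async def run_with_retry(agent_engine, user_id, message, max_retries=3):
    # Create the session up front, then retry only the streaming query.
    session = await agent_engine.async_create_session(user_id=user_id, state=None)
    for attempt in range(max_retries):
        try:
            responses = []
            async for event in agent_engine.async_stream_query(
                user_id=user_id, session_id=session["id"], message=message
            ):
                responses.append(event)
            return responses
        except Exception as e:  # placeholder for the narrower exceptions used above
            if attempt == max_retries - 1:
                return {"error": f"Failed after retries: {e}"}
            await asyncio.sleep(2**attempt)  # exponential backoff between attempts
    return {"error": f"Failed to get agent run results after {max_retries} retries"}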