diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 9e5c9bb2ec..ae0ddd140c 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -34,6 +34,7 @@ from ..events.event import Event from ..events.event_actions import EventActions from ..events.event_actions import EventCompaction +from ..models.cache_metadata import CacheMetadata from ..utils.vertex_ai_utils import get_express_mode_api_key from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig @@ -303,6 +304,14 @@ async def append_event(self, session: Session, event: Event) -> Event: else None ), } + if event.usage_metadata: + metadata_dict['usage_metadata'] = event.usage_metadata.model_dump( + exclude_none=True, mode='json' + ) + if event.cache_metadata: + metadata_dict['cache_metadata'] = event.cache_metadata.model_dump( + exclude_none=True, mode='json' + ) if event.grounding_metadata: metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump( exclude_none=True, mode='json' @@ -423,6 +432,14 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: getattr(event_metadata, 'grounding_metadata', None), types.GroundingMetadata, ) + usage_metadata = _session_util.decode_model( + getattr(event_metadata, 'usage_metadata', None), + types.GenerateContentResponseUsageMetadata, + ) + cache_metadata = _session_util.decode_model( + getattr(event_metadata, 'cache_metadata', None), + CacheMetadata, + ) else: long_running_tool_ids = None partial = None @@ -433,6 +450,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: compaction_data = None usage_metadata_data = None grounding_metadata = None + usage_metadata = None + cache_metadata = None if actions: actions_dict = actions.model_dump(exclude_none=True, mode='python') @@ -474,6 +493,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: branch=branch, custom_metadata=custom_metadata, grounding_metadata=grounding_metadata, + usage_metadata=usage_metadata, + cache_metadata=cache_metadata, long_running_tool_ids=long_running_tool_ids, usage_metadata=usage_metadata, ) diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 20fdbe3c6d..c8cf8a6adb 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -28,6 +28,7 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.events.event_actions import EventCompaction +from google.adk.models.cache_metadata import CacheMetadata from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService @@ -249,6 +250,8 @@ def _convert_to_object(data): 'artifact_delta', 'custom_metadata', 'requested_auth_configs', + 'cache_metadata', + 'usage_metadata', ]: kwargs[key] = value else: @@ -1039,3 +1042,52 @@ async def test_append_event_with_usage_metadata_and_compaction(): assert appended_event.custom_metadata == {'extra': 'info'} assert '_compaction' not in (appended_event.custom_metadata or {}) assert '_usage_metadata' not in (appended_event.custom_metadata or {}) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_cache_and_usage_metadata(): + """cache_metadata and usage_metadata round-trip through append and get.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session is not None + + cache_meta = CacheMetadata( + cache_name='projects/123/locations/us-central1/cachedContents/456', + expire_time=9999999999.0, + fingerprint='abc123hash', + invocations_used=3, + contents_count=10, + created_at=1700000000.0, + ) + usage_meta = genai_types.GenerateContentResponseUsageMetadata( + prompt_token_count=100, + candidates_token_count=50, + total_token_count=150, + cached_content_token_count=80, + ) + event_to_append = Event( + invocation_id='cache_test_invocation', + author='model', + timestamp=1734005536.0, + content=genai_types.Content( + parts=[genai_types.Part(text='cached response')] + ), + cache_metadata=cache_meta, + usage_metadata=usage_meta, + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert retrieved_session is not None + + appended_event = retrieved_session.events[-1] + # cache_metadata is preserved + assert appended_event.cache_metadata == cache_meta + # usage_metadata is preserved + assert appended_event.usage_metadata == usage_meta