|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import asyncio |
17 | 18 | from collections.abc import Mapping |
18 | 19 | from collections.abc import Sequence |
19 | | -from datetime import datetime |
| 20 | +import datetime |
20 | 21 | from functools import lru_cache |
21 | 22 | import logging |
22 | 23 | from typing import Optional |
|
38 | 39 |
|
39 | 40 | logger = logging.getLogger('google_adk.' + __name__) |
40 | 41 |
|
# Strong references to fire-and-forget tasks to prevent garbage collection.
# Tasks remove themselves on completion via add_done_callback(discard), so
# this set only holds in-flight requests.
# See https://docs.python.org/3/library/asyncio-task.html#creating-tasks
_background_tasks: set[asyncio.Task] = set()
| 45 | + |
41 | 46 | _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS = frozenset({ |
42 | 47 | 'disable_consolidation', |
43 | 48 | 'disable_memory_revisions', |
|
47 | 52 | 'revision_expire_time', |
48 | 53 | 'revision_labels', |
49 | 54 | 'revision_ttl', |
| 55 | + 'ttl', |
50 | 56 | 'wait_for_completion', |
51 | 57 | }) |
52 | 58 |
|
|
65 | 71 | 'wait_for_completion', |
66 | 72 | }) |
67 | 73 |
|
# custom_metadata keys consumed by the ingest_events request path; any key
# recognized by GenerateMemories but absent here routes the request to the
# generate_memories path instead.
_INGEST_EVENTS_CONFIG_FALLBACK_KEYS = frozenset({
    'force_flush',
    'generation_trigger_config',
    'stream_id',
})

# NOTE(review): usage of this key is not visible in this chunk — presumably
# a custom_metadata toggle for memory consolidation; confirm at call sites.
_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation'
| 81 | + |
| 82 | + |
def _should_use_generate_memories(
    custom_metadata: Mapping[str, object] | None,
) -> bool:
  """Returns True if custom_metadata contains keys only GenerateMemories supports.

  If any key in custom_metadata is recognized by GenerateMemories but NOT by
  IngestEvents, the generate_memories API path is used. Otherwise
  ingest_events is the default.
  """
  if not custom_metadata:
    return False
  # Keys exclusive to the GenerateMemories config.
  generate_only_keys = (
      _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
      - _INGEST_EVENTS_CONFIG_FALLBACK_KEYS
  )
  return any(key in generate_only_keys for key in custom_metadata)
| 100 | + |
| 101 | + |
69 | 102 | # Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow |
70 | 103 | # at most 5 direct_memories per request. |
71 | 104 | _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5 |
@@ -203,14 +236,35 @@ async def add_events_to_memory( |
203 | 236 | session_id: str | None = None, |
204 | 237 | custom_metadata: Mapping[str, object] | None = None, |
205 | 238 | ) -> None: |
206 | | - """Adds events to Vertex AI Memory Bank via memories.generate. |
| 239 | + """Adds events to Vertex AI Memory Bank. |
| 240 | +
|
| 241 | + Uses ``memories.ingest_events`` by default. If ``custom_metadata`` contains |
| 242 | + keys supported only by ``memories.generate`` (e.g. ``ttl``, |
| 243 | + ``revision_ttl``, ``metadata``, ``wait_for_completion``), the generate path |
| 244 | + is used instead. |
207 | 245 |
|
208 | 246 | Args: |
209 | 247 | app_name: The application name for memory scope. |
210 | 248 | user_id: The user ID for memory scope. |
211 | 249 | events: The events to process for memory generation. |
212 | 250 | session_id: Optional session ID. Currently unused. |
213 | | - custom_metadata: Optional service-specific metadata for generate config. |
| 251 | + custom_metadata: Optional service-specific metadata. Supported keys |
| 252 | + depend on the API path chosen: |
| 253 | +
|
| 254 | + **IngestEvents keys** (default path): |
| 255 | + stream_id: Identifier for the event stream. |
| 256 | + force_flush: If True, forces flushing buffered events. |
| 257 | + generation_trigger_config: Configuration for triggering memory |
| 258 | + generation, e.g. |
| 259 | + ``{"generation_rule": {"idle_duration": "60s"}}``. |
| 260 | +
|
| 261 | + **GenerateMemories keys** (used when any of these are present): |
| 262 | + ttl: Time-to-live for generated memories, e.g. ``"6000s"``. |
| 263 | + revision_ttl: Time-to-live for memory revisions. |
| 264 | + metadata: A mapping of custom metadata key-value pairs. |
| 265 | + wait_for_completion: Whether to wait for generation to complete. |
| 266 | + disable_consolidation: Disable memory consolidation. |
| 267 | + disable_memory_revisions: Disable memory revisions. |
214 | 268 | """ |
215 | 269 | _ = session_id |
216 | 270 | await self._add_events_to_memory_from_events( |
@@ -260,30 +314,139 @@ async def _add_events_to_memory_from_events( |
260 | 314 | events_to_process: Sequence[Event], |
261 | 315 | custom_metadata: Mapping[str, object] | None = None, |
262 | 316 | ) -> None: |
| 317 | + # The generate_memories API is used only when custom_metadata contains |
| 318 | + # keys exclusive to GenerateMemories. Otherwise, ingest_events is the |
| 319 | + # default path, as its behavior is consistent with GenerateMemories |
| 320 | + # (trigger immediately) and supports additional parameters like |
| 321 | + # generation_trigger_config. |
| 322 | + if _should_use_generate_memories(custom_metadata): |
| 323 | + import vertexai |
| 324 | + |
| 325 | + direct_events = [] |
| 326 | + for event in events_to_process: |
| 327 | + if _should_filter_out_event(event.content): |
| 328 | + continue |
| 329 | + if event.content: |
| 330 | + direct_events.append( |
| 331 | + vertexai.types.GenerateMemoriesRequestDirectContentsSourceEvent( |
| 332 | + content=event.content |
| 333 | + ) |
| 334 | + ) |
| 335 | + if direct_events: |
| 336 | + api_client = self._get_api_client() |
| 337 | + config = _build_generate_memories_config(custom_metadata) |
| 338 | + operation = await api_client.agent_engines.memories.generate( |
| 339 | + name='reasoningEngines/' + self._agent_engine_id, |
| 340 | + direct_contents_source=vertexai.types.GenerateMemoriesRequestDirectContentsSource( |
| 341 | + events=direct_events |
| 342 | + ), |
| 343 | + scope={ |
| 344 | + 'app_name': app_name, |
| 345 | + 'user_id': user_id, |
| 346 | + }, |
| 347 | + config=config, |
| 348 | + ) |
| 349 | + logger.info('Generate memory response received.') |
| 350 | + logger.debug('Generate memory response: %s', operation) |
| 351 | + else: |
| 352 | + logger.info('No events to add to memory.') |
| 353 | + return |
| 354 | + |
| 355 | + await self._add_events_to_memory_via_ingest( |
| 356 | + app_name=app_name, |
| 357 | + user_id=user_id, |
| 358 | + events_to_process=events_to_process, |
| 359 | + custom_metadata=custom_metadata, |
| 360 | + ) |
| 361 | + |
  async def _add_events_to_memory_via_ingest(
      self,
      *,
      app_name: str,
      user_id: str,
      events_to_process: Sequence[Event],
      custom_metadata: Mapping[str, object] | None = None,
  ) -> None:
    """Adds events to Vertex AI Memory Bank via memories.ingest_events.

    The request is dispatched as a fire-and-forget asyncio task: this
    coroutine returns before the API call completes, and failures surface
    only through the background task's error logging.

    Args:
      app_name: The application name for memory scope.
      user_id: The user ID for memory scope.
      events_to_process: The events to process for memory ingestion.
      custom_metadata: Optional service-specific metadata. Supported keys:
        stream_id: Identifier for the event stream.
        force_flush: If True, forces flushing buffered events (passed as
          part of the ingest_events config).
        generation_trigger_config: Configuration for triggering memory
          generation, e.g.
          ``{"generation_rule": {"idle_duration": "60s"}}``.
    """
    import vertexai

    direct_events = []
    for event in events_to_process:
      # Drop events whose content has no meaningful parts.
      if _should_filter_out_event(event.content):
        continue
      if event.content:
        event_time = None
        if event.timestamp is not None:
          # NOTE(review): assumes event.timestamp is POSIX epoch seconds —
          # confirm against the Event definition.
          event_time = datetime.datetime.fromtimestamp(
              event.timestamp, tz=datetime.timezone.utc
          )
        direct_events.append(
            vertexai.types.IngestionDirectContentsSourceEvent(
                content=event.content,
                event_id=event.id,
                event_time=event_time,
            )
        )

    api_client = self._get_api_client()

    stream_id = custom_metadata.get('stream_id') if custom_metadata else None
    force_flush = (
        custom_metadata.get('force_flush') if custom_metadata else None
    )
    generation_trigger_config = (
        custom_metadata.get('generation_trigger_config')
        if custom_metadata
        else None
    )

    request_kwargs: dict[str, object] = {
        'name': 'reasoningEngines/' + self._agent_engine_id,
        'scope': {
            'app_name': app_name,
            'user_id': user_id,
        },
    }
    # No-events requests are valid for trigger config updates, but
    # won't trigger an events flush.
    if direct_events:
      request_kwargs['direct_contents_source'] = (
          vertexai.types.IngestionDirectContentsSource(events=direct_events)
      )
    if stream_id:
      request_kwargs['stream_id'] = stream_id
    # force_flush is part of the ingest_events config, not a
    # top-level request parameter.
    config: dict[str, object] = {}
    if force_flush is not None:
      config['force_flush'] = force_flush
    if config:
      request_kwargs['config'] = config
    if generation_trigger_config:
      request_kwargs['generation_trigger_config'] = generation_trigger_config

    # Fire the ingest request without blocking. IngestEvents latency
    # (~800ms to trigger) makes awaiting unnecessary outside debugging.
    # The strong reference in _background_tasks keeps the task alive until
    # completion; the done callbacks clean up and log any failure.
    task = asyncio.create_task(
        api_client.agent_engines.memories.ingest_events(**request_kwargs)
    )
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    task.add_done_callback(_log_ingest_task_error)
    logger.info('Ingest events request triggered.')
287 | 450 |
|
288 | 451 | async def _add_memories_via_create( |
289 | 452 | self, |
@@ -402,12 +565,31 @@ def _get_api_client(self) -> vertexai.AsyncClient: |
402 | 565 | return vertexai.Client(project=self._project, location=self._location).aio |
403 | 566 |
|
404 | 567 |
|
| 568 | +def _log_ingest_task_error(task: asyncio.Task) -> None: |
| 569 | + """Logs errors from fire-and-forget ingest_events tasks.""" |
| 570 | + if task.cancelled(): |
| 571 | + return |
| 572 | + exception = task.exception() |
| 573 | + if exception: |
| 574 | + logger.error('Background ingest_events task failed: %s', exception) |
| 575 | + |
| 576 | + |
405 | 577 | def _should_filter_out_event(content: types.Content) -> bool: |
406 | 578 | """Returns whether the event should be filtered out.""" |
407 | 579 | if not content or not content.parts: |
408 | 580 | return True |
409 | 581 | for part in content.parts: |
410 | | - if part.text or part.inline_data or part.file_data: |
| 582 | + if ( |
| 583 | + part.text |
| 584 | + or part.inline_data |
| 585 | + or part.file_data |
| 586 | + or part.function_call |
| 587 | + or part.function_response |
| 588 | + or part.executable_code |
| 589 | + or part.code_execution_result |
| 590 | + or part.tool_call |
| 591 | + or part.tool_response |
| 592 | + ): |
411 | 593 | return False |
412 | 594 | return True |
413 | 595 |
|
@@ -742,7 +924,7 @@ def _to_vertex_metadata_value( |
742 | 924 | return {'double_value': float(value)} |
743 | 925 | if isinstance(value, str): |
744 | 926 | return {'string_value': value} |
745 | | - if isinstance(value, datetime): |
| 927 | + if isinstance(value, datetime.datetime): |
746 | 928 | return {'timestamp_value': value} |
747 | 929 | if isinstance(value, Mapping): |
748 | 930 | if value.keys() <= { |
|
0 commit comments