Skip to content

Commit 452a9fa

Browse files
feat: add async_creation option to ContextCacheConfig
CachedContent.create() API calls can take 30-40 seconds, blocking the user's request when a cache needs to be recreated. This adds an `async_creation` config option that defers cache creation to a background asyncio task, letting the current request proceed uncached while the cache is built for the next request. When async_creation=False (default), behavior is completely unchanged.
1 parent 6770e41 commit 452a9fa

3 files changed

Lines changed: 249 additions & 13 deletions

File tree

src/google/adk/agents/context_cache_config.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ class ContextCacheConfig(BaseModel):
7272
),
7373
)
7474

75+
async_creation: bool = Field(
76+
default=False,
77+
description=(
78+
"When True, cache creation is performed in the background instead of"
79+
" blocking the current request. The current request proceeds uncached"
80+
" and the cache is available for the next request. This eliminates"
81+
" latency spikes from slow CachedContent.create() API calls."
82+
),
83+
)
84+
7585
@property
7686
def ttl_string(self) -> str:
7787
"""Get TTL as string format for cache creation."""
@@ -81,5 +91,6 @@ def __str__(self) -> str:
8191
"""String representation for logging."""
8292
return (
8393
f"ContextCacheConfig(cache_intervals={self.cache_intervals}, "
84-
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens})"
94+
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens}, "
95+
f"async_creation={self.async_creation})"
8596
)

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 235 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
import hashlib
2021
import json
2122
import logging
@@ -32,6 +33,72 @@
3233

3334
logger = logging.getLogger("google_adk." + __name__)
3435

36+
# Background cache task registry for async_creation mode.
37+
# Key: (model, fingerprint, contents_count)
38+
# Value: asyncio.Task that resolves to Optional[CacheMetadata]
39+
_pending_cache_tasks: dict[tuple[str, str, int], asyncio.Task] = {}
40+
41+
42+
def _cache_task_key(
43+
model: str, fingerprint: str, contents_count: int
44+
) -> tuple[str, str, int]:
45+
"""Build a registry key for a pending cache task."""
46+
return (model, fingerprint, contents_count)
47+
48+
49+
def _check_pending_cache(
50+
key: tuple[str, str, int],
51+
) -> Optional[CacheMetadata]:
52+
"""Check if a background cache task completed successfully.
53+
54+
Returns the CacheMetadata if the task is done and succeeded,
55+
None if the task is still running or failed.
56+
Cleans up the registry entry in either done case.
57+
"""
58+
task = _pending_cache_tasks.get(key)
59+
if task is None:
60+
return None
61+
62+
if not task.done():
63+
return None
64+
65+
# Task is done - remove from registry regardless of outcome
66+
del _pending_cache_tasks[key]
67+
68+
if task.cancelled():
69+
logger.warning("Background cache task was cancelled for key %s", key)
70+
return None
71+
72+
exc = task.exception()
73+
if exc is not None:
74+
logger.warning("Background cache task failed for key %s: %s", key, exc)
75+
return None
76+
77+
return task.result()
78+
79+
80+
def _cleanup_stale_tasks() -> None:
81+
"""Remove completed tasks that have been sitting unclaimed."""
82+
to_remove = []
83+
for key, task in _pending_cache_tasks.items():
84+
if not task.done():
85+
continue
86+
result = None
87+
try:
88+
if not task.cancelled():
89+
exc = task.exception()
90+
if exc is None:
91+
result = task.result()
92+
except Exception:
93+
pass
94+
if result is None or (
95+
result.expire_time is not None and time.time() >= result.expire_time
96+
):
97+
to_remove.append(key)
98+
for key in to_remove:
99+
del _pending_cache_tasks[key]
100+
101+
35102
if TYPE_CHECKING:
36103
from google.genai import Client
37104

@@ -62,13 +129,55 @@ async def handle_context_caching(
62129
the cache to the request by setting cached_content and removing cached
63130
contents from the request.
64131
132+
When async_creation is enabled in the cache config, cache creation is
133+
performed in the background instead of blocking the current request.
134+
65135
Args:
66136
llm_request: Request that may contain cache config and metadata.
67137
Modified in-place to use the cache.
68138
69139
Returns:
70140
Cache metadata to be included in response, or None if caching failed
71141
"""
142+
async_creation = (
143+
llm_request.cache_config
144+
and llm_request.cache_config.async_creation
145+
)
146+
147+
# Opportunistically clean up stale background tasks
148+
if async_creation:
149+
_cleanup_stale_tasks()
150+
151+
# Check for completed background cache creation (async_creation mode)
152+
if (
153+
async_creation
154+
and llm_request.cache_metadata
155+
and llm_request.cache_metadata.cache_name is None
156+
):
157+
fp = llm_request.cache_metadata.fingerprint
158+
cc = llm_request.cache_metadata.contents_count
159+
model = llm_request.model or ""
160+
key = _cache_task_key(model, fp, cc)
161+
bg_result = _check_pending_cache(key)
162+
if bg_result and bg_result.cache_name:
163+
if time.time() < bg_result.expire_time:
164+
logger.info(
165+
"Using background-created cache: %s",
166+
bg_result.cache_name,
167+
)
168+
self._apply_cache_to_request(
169+
llm_request,
170+
bg_result.cache_name,
171+
bg_result.contents_count,
172+
)
173+
return bg_result
174+
else:
175+
logger.info(
176+
"Background-created cache already expired: %s",
177+
bg_result.cache_name,
178+
)
179+
await self.cleanup_cache(bg_result.cache_name)
180+
72181
# Check if we have existing cache metadata and if it's valid
73182
if llm_request.cache_metadata:
74183
logger.debug(
@@ -107,17 +216,37 @@ async def handle_context_caching(
107216

108217
# If fingerprints match, create new cache (expired but same content)
109218
if current_fingerprint == old_cache_metadata.fingerprint:
110-
logger.debug(
111-
"Fingerprints match after invalidation, creating new cache"
112-
)
113-
cache_metadata = await self._create_new_cache_with_contents(
114-
llm_request, cache_contents_count
115-
)
116-
if cache_metadata:
117-
self._apply_cache_to_request(
118-
llm_request, cache_metadata.cache_name, cache_contents_count
219+
if async_creation:
220+
# Launch background cache creation and proceed uncached
221+
key = _cache_task_key(
222+
llm_request.model or "",
223+
current_fingerprint,
224+
cache_contents_count,
225+
)
226+
self._launch_background_cache(
227+
key, llm_request, cache_contents_count
228+
)
229+
logger.debug(
230+
"Async cache creation launched, proceeding uncached"
119231
)
120-
return cache_metadata
232+
return CacheMetadata(
233+
fingerprint=current_fingerprint,
234+
contents_count=cache_contents_count,
235+
)
236+
else:
237+
logger.debug(
238+
"Fingerprints match after invalidation, creating new cache"
239+
)
240+
cache_metadata = await self._create_new_cache_with_contents(
241+
llm_request, cache_contents_count
242+
)
243+
if cache_metadata:
244+
self._apply_cache_to_request(
245+
llm_request,
246+
cache_metadata.cache_name,
247+
cache_contents_count,
248+
)
249+
return cache_metadata
121250

122251
# Fingerprints don't match - recalculate with total contents
123252
logger.debug(
@@ -127,6 +256,18 @@ async def handle_context_caching(
127256
fingerprint_for_all = self._generate_cache_fingerprint(
128257
llm_request, total_contents_count
129258
)
259+
260+
if async_creation and total_contents_count > 0:
261+
# Launch background cache creation for the new fingerprint
262+
key = _cache_task_key(
263+
llm_request.model or "",
264+
fingerprint_for_all,
265+
total_contents_count,
266+
)
267+
self._launch_background_cache(
268+
key, llm_request, total_contents_count
269+
)
270+
130271
return CacheMetadata(
131272
fingerprint=fingerprint_for_all,
132273
contents_count=total_contents_count,
@@ -146,6 +287,90 @@ async def handle_context_caching(
146287
contents_count=total_contents_count,
147288
)
148289

290+
def _launch_background_cache(
291+
self,
292+
key: tuple[str, str, int],
293+
llm_request: LlmRequest,
294+
contents_count: int,
295+
) -> None:
296+
"""Launch cache creation as a background asyncio task.
297+
298+
Creates a snapshot of the request data needed for cache creation,
299+
then fires off the creation in a background task.
300+
301+
Args:
302+
key: Registry key for the pending task
303+
llm_request: Request to create cache for (will be snapshotted)
304+
contents_count: Number of contents to cache
305+
"""
306+
if key in _pending_cache_tasks:
307+
task = _pending_cache_tasks[key]
308+
if not task.done():
309+
logger.debug(
310+
"Background cache creation already in progress for key %s",
311+
key,
312+
)
313+
return
314+
del _pending_cache_tasks[key]
315+
316+
# Snapshot the request data before it gets mutated
317+
snapshot = self._snapshot_request(llm_request, contents_count)
318+
genai_client = self.genai_client
319+
320+
async def _do_create() -> Optional[CacheMetadata]:
321+
mgr = GeminiContextCacheManager(genai_client)
322+
return await mgr._create_new_cache_with_contents(
323+
snapshot, contents_count
324+
)
325+
326+
loop = asyncio.get_running_loop()
327+
task = loop.create_task(
328+
_do_create(),
329+
name=f"bg-cache-{key[1][:8]}",
330+
)
331+
_pending_cache_tasks[key] = task
332+
logger.info("Launched background cache creation for key %s", key)
333+
334+
def _snapshot_request(
335+
self,
336+
llm_request: LlmRequest,
337+
contents_count: int,
338+
) -> LlmRequest:
339+
"""Create a minimal snapshot of the request for background cache creation.
340+
341+
Captures only the fields that _create_gemini_cache needs, so the
342+
background task is not affected by mutations to the original request.
343+
344+
Args:
345+
llm_request: Original request to snapshot
346+
contents_count: Number of contents to include
347+
348+
Returns:
349+
A new LlmRequest with just the fields needed for cache creation
350+
"""
351+
config = types.GenerateContentConfig(
352+
system_instruction=(
353+
llm_request.config.system_instruction
354+
if llm_request.config
355+
else None
356+
),
357+
tools=(
358+
llm_request.config.tools if llm_request.config else None
359+
),
360+
tool_config=(
361+
llm_request.config.tool_config if llm_request.config else None
362+
),
363+
)
364+
return LlmRequest(
365+
model=llm_request.model,
366+
contents=list(llm_request.contents[:contents_count]),
367+
config=config,
368+
cache_config=llm_request.cache_config,
369+
cacheable_contents_token_count=(
370+
llm_request.cacheable_contents_token_count
371+
),
372+
)
373+
149374
def _find_count_of_contents_to_cache(
150375
self, contents: list[types.Content]
151376
) -> int:

tests/unittests/agents/test_context_cache_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@ def test_str_representation(self):
106106
)
107107

108108
expected = (
109-
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024)"
109+
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024, async_creation=False)"
110110
)
111111
assert str(config) == expected
112112

113113
def test_str_representation_defaults(self):
114114
"""Test string representation with default values."""
115115
config = ContextCacheConfig()
116116

117-
expected = "ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0)"
117+
expected = "ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0, async_creation=False)"
118118
assert str(config) == expected
119119

120120
def test_pydantic_model_validation(self):

0 commit comments

Comments (0)