Commit 44d436b
Custom op to update cache for torch.cond (#16366)
torch.cond doesn't allow aliasing or mutation in its branches. This adds 2 ops to support conditionally updating the KV cache:

* `executorch::alias`: takes 2 tensors and returns the same 2 tensors.
* `executorch::update_cross_attn_cache`: takes a tensor `cache` and a tensor `value`, and copies `value` into `cache` in place.

With these 2 ops, we can rewrite the model definition from:

```py
if is_cross_attention and past_key_values and is_updated:
    # reuse k,v, cross_attentions
    key_states = past_key_values.layers[self.layer_idx].keys
    value_states = past_key_values.layers[self.layer_idx].values
else:
    key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()

    if past_key_values is not None:
        # save all key/value_states to cache to be re-used for fast auto-regressive generation
        cache_position = cache_position if not is_cross_attention else None
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, {"cache_position": cache_position}
        )
```

Into:

```py
def use_cached_kv(
    cached_keys: Tensor,
    cached_values: Tensor,
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Just reuse cached K/V
    return torch.ops.executorch.alias(cached_keys, cached_values)


def recompute_kv(
    cached_keys: Tensor,  # unused
    cached_values: Tensor,  # unused
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Compute fresh K/V (export-friendly: use custom op to mutate cache)
    key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
    v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
    return k, v


if past_key_values is not None and self.layer_idx is not None:
    # Grab cached tensors (these are Tensors, so they are OK for export)
    cached_keys = past_key_values.layers[self.layer_idx].keys
    cached_values = past_key_values.layers[self.layer_idx].values

    # Tensor predicate: True if any element is non-zero
    # Result is a 0-dim bool tensor suitable for torch.cond
    cache_is_initialized = (cached_keys != 0).any()

    # Use torch.cond to select branch in a traceable way.
    # All operands must be (nested) tensors or simple Python values.
    key_states, value_states = torch.cond(
        cache_is_initialized,
        use_cached_kv,
        recompute_kv,
        operands=(cached_keys, cached_values, key_value_states),
    )
```
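For orientation, here is a self-contained toy sketch (not part of this commit) of the same pattern: a module that owns zero-initialized K/V cache buffers, selects between the two new ops with torch.cond, and is then passed through torch.export. It assumes extension/llm/custom_ops/custom_ops.py has been imported so that the `executorch::alias` and `executorch::update_cross_attn_cache` ops are registered; all class names, shapes, and argument names are illustrative only.

```py
# Illustrative sketch only; assumes custom_ops.py (below) has been imported so the ops exist.
import torch
from torch import Tensor


class ToyCrossAttnKV(torch.nn.Module):
    def __init__(self, bsz=1, heads=2, max_seq=8, head_dim=4):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(bsz, heads, max_seq, head_dim))
        self.register_buffer("v_cache", torch.zeros(bsz, heads, max_seq, head_dim))

    def forward(self, new_keys: Tensor, new_values: Tensor):
        def use_cached_kv(k_cache, v_cache, k_new, v_new):
            # Cache already filled: pass it through via the alias op.
            return torch.ops.executorch.alias(k_cache, v_cache)

        def recompute_kv(k_cache, v_cache, k_new, v_new):
            # First call: write the fresh K/V into the caches via the custom op.
            k = torch.ops.executorch.update_cross_attn_cache(k_new, k_cache)
            v = torch.ops.executorch.update_cross_attn_cache(v_new, v_cache)
            return k, v

        # 0-dim bool tensor predicate: cache counts as initialized if any entry is non-zero.
        cache_is_initialized = (self.k_cache != 0).any()
        return torch.cond(
            cache_is_initialized,
            use_cached_kv,
            recompute_kv,
            operands=(self.k_cache, self.v_cache, new_keys, new_values),
        )


m = ToyCrossAttnKV()
example_k = torch.randn(1, 2, 8, 4)
example_v = torch.randn(1, 2, 8, 4)
# Both branches are traced with the ops' fake kernels, so export should not trip over
# torch.cond's aliasing/mutation restrictions.
ep = torch.export.export(m, (example_k, example_v))
```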
1 parent 1ce615e commit 44d436b

File tree

6 files changed: +436 -18 lines
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-7a064ed3eafa43f17412d434b395240c727b3000
+7a79b41e29a790ebb4b530eb98a89381e2d7de29

extension/llm/custom_ops/TARGETS

Lines changed: 13 additions & 0 deletions

@@ -63,3 +63,16 @@ runtime.python_test(
         "//executorch/extension/pybindings:portable_lib",
     ],
 )
+
+runtime.python_test(
+    name = "test_update_cross_attn_cache",
+    srcs = [
+        "test_update_cross_attn_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
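The new target points at test_update_cross_attn_cache.py, which is among the six changed files but is not shown in this view. Below is a hypothetical sketch of the kind of eager check such a test might contain; the import path, class name, and test name are assumptions, not the actual file.

```py
# Hypothetical sketch only; the real test_update_cross_attn_cache.py is not shown here.
import unittest

import torch

# Assumed import path; importing the module registers the executorch::* ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class UpdateCrossAttnCacheTest(unittest.TestCase):
    def test_copies_value_into_cache_prefix(self):
        cache = torch.zeros(1, 2, 8, 4)  # (B, H, S_max, D)
        value = torch.randn(1, 2, 3, 4)  # (B, H, S, D) with S <= S_max
        out = torch.ops.executorch.update_cross_attn_cache(value, cache)
        # The first S positions of the cache now hold `value`; the rest stay zero.
        self.assertTrue(torch.equal(cache[:, :, :3, :], value))
        self.assertTrue(torch.equal(cache[:, :, 3:, :], torch.zeros(1, 2, 5, 4)))
        # The op returns a clone of the updated cache rather than the cache itself.
        self.assertTrue(torch.equal(out, cache))
        self.assertNotEqual(out.data_ptr(), cache.data_ptr())


if __name__ == "__main__":
    unittest.main()
```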

extension/llm/custom_ops/custom_ops.py

Lines changed: 106 additions & 0 deletions

@@ -12,10 +12,16 @@
 
 import logging
 
+from typing import Tuple
+
 import torch
 
+from torch._inductor.lowering import lowerings as L, register_lowering
+
 from torch.library import impl
 
+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,103 @@ def custom_quantized_sdpa_meta(
     )
 
     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata at compile time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op; just pass them through.
+    return x, y
+
+
+# Expected cache shape: (B, H, S_max, D); value shape: (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# torch.cond requires its branch functions to be mutation-free, so we intentionally
+# declare no mutations here: `cache` is omitted from `mutates_args`, accepting the
+# hidden in-place update for inference use.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    """
+    Update cross-attention KV cache with new values.
+
+    Copies the value tensor into the beginning of the cache tensor along the
+    sequence dimension. This is used for cross-attention caching where the
+    encoder outputs are computed once and reused across decoding steps.
+
+    Args:
+        value: New values to store in cache. Shape: [B, H, S, D] where
+            B = batch size, H = num heads, S = sequence length, D = head dim.
+        cache: Pre-allocated cache tensor to update. Shape: [B, H, S_max, D]
+            where S_max >= S.
+
+    Returns:
+        A clone of the updated cache tensor. This differs from the Inductor
+        lowering, which returns the cache tensor itself: returning the input
+        buffer directly here would fail the aliasing checks that higher-order
+        ops perform.
+
+    Note:
+        The cache is mutated in-place, but we return a clone to avoid aliasing
+        issues with the exported program.
+    """
+    _validate_cross_attn_cache_params(value, cache)
+    cache[:, :, : value.size(2), :].copy_(value)
+    return cache.clone()
+
+
+# Register the fake (meta) kernel
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register the Inductor lowering
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # Slice the cache along dim 2 (sequence length):
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache

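What makes `executorch::alias` usable inside torch.cond is the fake kernel above: at trace time it returns fresh tensors with the same metadata, so the traced outputs do not alias the op's inputs even though the eager kernel just passes them through. A small sketch of that compile-time behavior (not part of this commit; the import path is assumed from the repo layout):

```py
# Sketch only: exercise the registered fake kernel for executorch::alias under
# FakeTensorMode, the tensor mode used during tracing/compilation.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumed import path; importing the module registers the executorch::* ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

with FakeTensorMode():
    x = torch.randn(1, 2, 8, 4)  # fake tensors inside the mode
    y = torch.randn(1, 2, 8, 4)
    a, b = torch.ops.executorch.alias(x, y)
    # The fake kernel returns empty_like(x), empty_like(y): same metadata,
    # but brand-new tensors, so the outputs do not alias the inputs.
    assert a.shape == x.shape and a.dtype == x.dtype
    assert a is not x and b is not y
```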