Commit 44d436b
authored
Custom op to update cache for torch.cond (#16366)
`torch.cond` does not allow aliasing or mutation in its branch functions. This commit adds 2 ops
to support conditionally updating the KV cache:
* `executorch::alias`: takes 2 tensors and returns the same 2 tensors.
* `executorch::update_cross_attn_cache`: takes a tensor `value` and a
tensor `cache` (in that order, as in the example below) and copies `value` into `cache` in place.
With these 2 ops, we can rewrite the model definition from:
```py
if is_cross_attention and past_key_values and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
if past_key_values is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
```
Into:
```py
def use_cached_kv(
cached_keys: Tensor,
cached_values: Tensor,
key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
# Just reuse cached K/V
return torch.ops.executorch.alias(cached_keys, cached_values)
def recompute_kv(
cached_keys: Tensor, # destination cache, updated in place below
cached_values: Tensor, # destination cache, updated in place below
key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
# Compute fresh K/V (export-friendly: use custom op to mutate cache)
key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
return k, v
if past_key_values is not None and self.layer_idx is not None:
# Grab cached tensors (these are Tensors, so they are OK for export)
cached_keys = past_key_values.layers[self.layer_idx].keys
cached_values = past_key_values.layers[self.layer_idx].values
# Tensor predicate: True if any element is non-zero
# Result is a 0-dim bool tensor suitable for torch.cond
cache_is_initialized = (cached_keys != 0).any()
# Use torch.cond to select branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
cache_is_initialized,
use_cached_kv,
recompute_kv,
operands=(cached_keys, cached_values, key_value_states),
)
```

1 parent 1ce615e commit 44d436b
File tree
6 files changed
+436
-18
lines changed
- .ci/docker/ci_commit_pins
- extension/llm/custom_ops
- runtime/core/portable_type/c10/c10/util
6 files changed
+436
-18
lines changed

| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
| 1 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
63 | 63 | | |
64 | 64 | | |
65 | 65 | | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
15 | 17 | | |
16 | 18 | | |
| 19 | + | |
| 20 | + | |
17 | 21 | | |
18 | 22 | | |
| 23 | + | |
| 24 | + | |
19 | 25 | | |
20 | 26 | | |
21 | 27 | | |
| |||
387 | 393 | | |
388 | 394 | | |
389 | 395 | | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
0 commit comments