
Commit 2827e32

Move torch.cond predicate non-persistent buffer to CPU
Avoid device-to-host memory copies when evaluating `torch.cond` predicates. When a GPU buffer (e.g., a KV cache `initialized` flag) is used as a predicate for `torch.cond`, the runtime must synchronize and copy the predicate value from GPU to CPU on every forward pass to evaluate the condition. This adds latency and synchronization overhead.

`MoveCondPredicateToCpuPass` moves non-persistent buffer predicates to CPU at export time, eliminating per-inference D2H transfers. The predicate is typically a small scalar (e.g., a boolean flag), so keeping it on CPU has negligible memory impact.

- Add `MoveCondPredicateToCpuPass` in `backends/cuda/passes/`
- Add unit tests covering:
  - GPU buffer predicates moved to CPU
  - CPU buffer predicates unchanged
  - Computed predicates unaffected
  - Multiple `torch.cond` calls
  - Cross-attention cache pattern
  - Persistent buffers (state_dict) not moved
- Add Python tests to `unittest-cuda` CI job in `cuda.yml`

ghstack-source-id: b439eb3
ghstack-comment-id: 3687889864
Pull-Request: #16378
1 parent c5d66a5 commit 2827e32
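The latency claim above comes down to the host-side read that predicate evaluation implies. The following is a minimal sketch, not part of the commit, that assumes a CUDA device is available: it times reading a GPU-resident flag versus a CPU-resident one on the host. Reading a CUDA scalar on the host forces a synchronize plus a device-to-host copy, which is the per-call cost the pass removes.

```python
# Sketch only (not from the commit): illustrates the per-call D2H copy a
# GPU-resident torch.cond predicate incurs. Assumes a CUDA device.
import time

import torch

gpu_flag = torch.tensor([False], device="cuda")  # e.g. a KV-cache "initialized" flag
cpu_flag = torch.tensor([False])                 # the same flag kept on CPU


def time_predicate_eval(flag: torch.Tensor, iters: int = 1000) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        # Converting a CUDA tensor to a Python bool synchronizes and copies
        # the scalar to the host; on a CPU tensor it is a plain memory read.
        bool(flag)
    return (time.perf_counter() - start) / iters


print(f"GPU predicate: {time_predicate_eval(gpu_flag) * 1e6:.1f} us/eval")
print(f"CPU predicate: {time_predicate_eval(cpu_flag) * 1e6:.1f} us/eval")
```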

File tree

7 files changed: +662, -7 lines


.github/workflows/cuda.yml

Lines changed: 8 additions & 5 deletions
@@ -87,8 +87,8 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
-  test-cuda-shims:
-    name: test-cuda-shims
+  unittest-cuda:
+    name: unittest-cuda
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
@@ -103,17 +103,20 @@
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
-        # Install requirements
-        bash ./install_requirements.sh
+        # Install executorch in editable mode so custom op libs land in-tree
+        bash ./install_executorch.sh
 
         # Build ExecuTorch with CUDA support
         cmake --workflow --preset llm-release-cuda
 
-        # Build and run CUDA shim tests
+        # Build and run CUDA shim tests (C++)
         pushd backends/cuda/runtime/shims/tests
         cmake --workflow --preset default
         popd
 
+        # Run CUDA backend Python tests, overrides addopts so that we don't run all tests in pytest.ini
+        python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)

backends/aoti/aoti_backend.py

Lines changed: 4 additions & 1 deletion
@@ -156,7 +156,10 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes(compile_specs)
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            if getattr(custom_pass, "requires_exported_program", False):
+                custom_pass(device_edge_program)
+            else:
+                custom_pass(device_edge_program.graph_module)
 
         # Run decompositions if any
         if decomposition_table:
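The `requires_exported_program` attribute acts as a small opt-in protocol between custom passes and `AotiBackend.preprocess`: passes that only rewrite the FX graph keep receiving the graph module, while passes that need export-level state (constants, graph signature) receive the whole `ExportedProgram`. A minimal sketch of that dispatch follows; `GraphOnlyPass`, `ProgramLevelPass`, and `run_custom_passes` are hypothetical names used only for illustration.

```python
# Sketch of the pass-dispatch protocol added in aoti_backend.py.
# GraphOnlyPass / ProgramLevelPass / run_custom_passes are illustrative names.
import torch
from torch.export import ExportedProgram


class GraphOnlyPass:
    """Default case: receives only the FX graph module."""

    def __call__(self, graph_module: torch.fx.GraphModule) -> None:
        # Node-level rewrites that need no export-level state go here.
        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                pass  # e.g. swap node.target for another op


class ProgramLevelPass:
    """Opts in to the full ExportedProgram (constants, graph signature, ...)."""

    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram) -> None:
        # Export-level state is available here, not just the graph.
        _ = exported_program.graph_signature


def run_custom_passes(exported_program: ExportedProgram, custom_passes) -> None:
    # Mirrors the dispatch in AotiBackend.preprocess: hand each pass the
    # richest object it declares it needs.
    for custom_pass in custom_passes:
        if getattr(custom_pass, "requires_exported_program", False):
            custom_pass(exported_program)
        else:
            custom_pass(exported_program.graph_module)
```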

backends/cuda/cuda_backend.py

Lines changed: 7 additions & 1 deletion
@@ -12,6 +12,9 @@
 
 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
+    MoveCondPredicateToCpuPass,
+)
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
@@ -155,7 +158,10 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
                 )
                 triton_kernel_mode = mode
 
-        return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
+        passes = [MoveCondPredicateToCpuPass()]
+        if triton_kernel_mode == "ON":
+            passes.append(ReplaceEdgeOpWithTritonOpPass())
+        return passes
 
     @classmethod
     def get_aoti_compile_options(

backends/cuda/passes/__init__.py

Whitespace-only changes.
backends/cuda/passes/move_cond_predicate_to_cpu.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.export import ExportedProgram
+
+
+class MoveCondPredicateToCpuPass:
+    """
+    A pass that moves the predicate of torch.cond to CPU if the predicate is a constant buffer.
+    This is useful for models that use the predicate as a constant buffer, such as an `initialized` flag for a cross-attention KV cache.
+
+    This saves ~50us per torch.cond call on RTX 5080.
+
+    Example:
+    ```
+    class CrossAttentionWithCache(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
+            # Buffer used as predicate for torch.cond
+            self.register_buffer("initialized", torch.tensor([False]), persistent=False)
+            self.register_buffer("k_cache", torch.zeros(1, 10, hidden_size), persistent=False)
+            self.register_buffer("v_cache", torch.zeros(1, 10, hidden_size), persistent=False)
+
+        def compute_kv(self, encoder_hidden_states):
+            k = self.k_proj(encoder_hidden_states)
+            v = self.v_proj(encoder_hidden_states)
+            self.k_cache.copy_(k)
+            self.v_cache.copy_(v)
+            self.initialized.fill_(True)
+            return k, v
+
+        def use_cached_kv(self, encoder_hidden_states):
+            return self.k_cache.clone(), self.v_cache.clone()
+
+        def forward(self, hidden_states, encoder_hidden_states):
+            q = self.q_proj(hidden_states)
+            # Use torch.cond with initialized buffer as predicate
+            k, v = torch.cond(
+                self.initialized,
+                self.use_cached_kv,
+                self.compute_kv,
+                (encoder_hidden_states,),
+            )
+            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+            return self.out_proj(attn_output)
+    ```
+    In this example, if we keep `self.initialized` on GPU, we have to copy it to CPU on every forward pass.
+    We move the predicate to CPU to avoid device-to-host copies.
+    This pass only applies to models that use torch.cond with a constant-buffer predicate.
+    """
+
+    requires_exported_program = True
+
+    def __call__(self, exported_program: ExportedProgram):
+        graph_module = exported_program.graph_module
+
+        # Map input names to buffer names
+        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers
+
+        for node in graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.cond
+            ):
+                pred_node = node.args[0]
+                if (
+                    pred_node.op == "placeholder"
+                    and pred_node.name in inputs_to_buffers
+                ):
+                    buffer_name = inputs_to_buffers[pred_node.name]
+
+                    if buffer_name in exported_program.constants:
+                        tensor = exported_program._constants[buffer_name]
+                        if tensor.device.type != "cpu":
+                            exported_program._constants[buffer_name] = tensor.to("cpu")
+
+                            # Also update the placeholder metadata
+                            if "val" in pred_node.meta:
+                                fake_tensor = pred_node.meta["val"]
+                                if isinstance(fake_tensor, torch.Tensor):
+                                    pred_node.meta["val"] = fake_tensor.to("cpu")
+        exported_program.validate()
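Below is a minimal usage sketch of the pass, modeled loosely on the tests described in the commit message rather than on the actual test file (which is not shown on this page). It assumes a CUDA device; `TinyCondModule` is an illustrative name, and the final check relies on non-persistent buffers being lifted into `ExportedProgram.constants`, the same assumption the pass itself makes.

```python
# Hedged usage sketch; the real tests live under backends/cuda/passes/tests/.
# Assumes a CUDA device. TinyCondModule is an illustrative module.
import torch

from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
    MoveCondPredicateToCpuPass,
)


class TinyCondModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer used as the torch.cond predicate.
        self.register_buffer("initialized", torch.tensor([False]), persistent=False)

    def forward(self, x):
        return torch.cond(self.initialized, lambda x: x + 1, lambda x: x - 1, (x,))


module = TinyCondModule().to("cuda")
example_inputs = (torch.randn(4, device="cuda"),)
ep = torch.export.export(module, example_inputs)

# Run the pass on the ExportedProgram (it declares requires_exported_program).
MoveCondPredicateToCpuPass()(ep)

# The lifted non-persistent predicate buffer should now be a CPU constant.
for name, tensor in ep.constants.items():
    if "initialized" in name:
        assert tensor.device.type == "cpu", name
```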

backends/cuda/passes/tests/__init__.py

Whitespace-only changes.
