
Commit a8be026

[Template/SDPA] Cleanup test case + Add an activate option
1 parent e925ae4 commit a8be026

5 files changed

Lines changed: 181 additions & 582 deletions


PyTorchSimDevice/csrc/aten/native/Extra.cpp

Lines changed: 32 additions & 2 deletions
@@ -20,8 +20,38 @@ int64_t _fused_sdp_choice(
     std::optional<double> scale,
     bool enable_gqa) {
 
-  auto backend = sdp::SDPBackend::overrideable;
-  return static_cast<int64_t>(backend);
+  sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal, enable_gqa};
+
+  // Reject inputs that are fundamentally unsupported (e.g. wrong rank).
+  if (!sdp::check_tensor_shapes(params, /*debug=*/false)) {
+    return static_cast<int64_t>(sdp::SDPBackend::error);
+  }
+
+  // q: (B, Hq, L, E)  k/v: (B, H, S, E)
+  const int64_t Hq = query.size(-3);
+  const int64_t H = key.size(-3);
+  const int64_t L = query.size(-2);  // query sequence length
+  const int64_t S = key.size(-2);    // key/value sequence length
+
+  // Conditions required by the MLIR FlashSDPA kernel:
+  //   Prefill only  : L == S (decode has L == 1, not supported)
+  //   Non-GQA       : Hq == H (equal query and KV heads)
+  //   No dropout    : template has no dropout implementation
+  //   Dense tensors : no nested tensor support
+  const bool can_use_mlir_flash =
+      (L == S) &&
+      (Hq == H) && !enable_gqa &&
+      sdp::check_for_dropout(params, /*debug=*/false) &&
+      sdp::check_nested_tensor(params, /*debug=*/false);
+
+  // Honor the per-process switch set via torch._C._set_sdp_use_overrideable().
+  const bool ctx_overrideable = at::globalContext().userEnabledOverrideableSDP();
+
+  if (ctx_overrideable && can_use_mlir_flash) {
+    return static_cast<int64_t>(sdp::SDPBackend::overrideable);
+  }
+
+  return static_cast<int64_t>(sdp::SDPBackend::math);
 }
 
 void quantize_tensor_per_tensor_affine_stub(
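
Note on the routing above: the overrideable (MLIR FlashSDPA) backend is chosen only for dense, dropout-free, non-GQA prefill shapes with the backend enabled; everything else falls back to math. A minimal Python mirror of that predicate, for reference only (the helper name choose_backend and the example shapes are invented here; the rank and nested-tensor checks are omitted):

# Sketch: a Python mirror of the new selection logic, not the C++ code itself.
# Shapes follow the kernel convention: q is (B, Hq, L, E), k/v are (B, H, S, E).
def choose_backend(q_shape, kv_shape, dropout_p, enable_gqa, overrideable_enabled):
    _, Hq, L, _ = q_shape
    _, H, S, _ = kv_shape
    can_use_mlir_flash = (L == S) and (Hq == H) and not enable_gqa and dropout_p == 0.0
    return "overrideable" if (overrideable_enabled and can_use_mlir_flash) else "math"

# Prefill, equal heads, no dropout -> MLIR FlashSDPA template.
assert choose_backend((1, 8, 128, 64), (1, 8, 128, 64), 0.0, False, True) == "overrideable"
# Decode (L == 1 != S) or GQA (Hq != H) -> math fallback and decomposition.
assert choose_backend((1, 8, 1, 64), (1, 8, 128, 64), 0.0, False, True) == "math"
assert choose_backend((1, 8, 128, 64), (1, 2, 128, 64), 0.0, False, True) == "math"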

PyTorchSimDevice/torch_openreg/openreg/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -73,6 +73,11 @@ def _lazy_init():
     register_interface_for_device(custom_device(), ExtensionDeviceInterface)
     _initialized = True
 
+    # Set default SDPA backend to math-only for this device.
+    torch._C._set_sdp_use_flash(False)
+    torch._C._set_sdp_use_overrideable(False)
+    torch._C._set_sdp_use_math(True)
+
     # Create default streams for all devices
     num_devices = device_count()
     for device_idx in range(num_devices):
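
A usage sketch of the new default: the overrideable backend starts disabled and must be switched back on before scaled_dot_product_attention can reach the MLIR path. The "openreg" device string is an assumption based on the package name, not something this commit states:

import torch
import torch.nn.functional as F

# Opt back in to the MLIR FlashSDPA path after _lazy_init() has disabled it.
torch._C._set_sdp_use_overrideable(True)

# Prefill shapes: L == S and Hq == H, so _fused_sdp_choice may pick
# SDPBackend::overrideable. The "openreg" device string is an assumption.
q = torch.randn(1, 8, 128, 64, device="openreg")
k = torch.randn(1, 8, 128, 64, device="openreg")
v = torch.randn(1, 8, 128, 64, device="openreg")
out = F.scaled_dot_product_attention(q, k, v)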

PyTorchSimFrontend/mlir/mlir_lowering.py

Lines changed: 14 additions & 46 deletions
@@ -20,8 +20,6 @@
 from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate
 from PyTorchSimFrontend.mlir.mlir_sdpa_template import (
     MLIRFlashSDPATemplate,
-    MLIRDecodeGQASDPAPartialTemplate,
-    MLIRDecodeGQASDPAReduceTemplate,
     flash_sdpa_args,
     calculate_scale,
 )
@@ -51,56 +49,27 @@ def tuned_bmm(mat1, mat2, *, layout=None):
 
 
 def tuned_flash_sdpa(
-        query : TensorBox,
-        key : TensorBox,
-        value : TensorBox,
+        query : TensorBox,
+        key : TensorBox,
+        value : TensorBox,
         attn_bias : Optional[TensorBox] = None,
-        dropout_p : float = 0.0,
-        is_causal : bool = False,
+        dropout_p : float = 0.0,
+        is_causal : bool = False,
         return_debug_mask : bool = False,
-        scale : Optional[float] = None) -> tuple:
-
-
+        scale : Optional[float] = None,
+        enable_gqa : bool = False) -> tuple:
+    # _fused_sdp_choice in C++ already guarantees:
+    #   L == S (prefill), Hq == H (non-GQA), dropout_p == 0.0
+    # before routing here via SDPBackend::overrideable.
+    # Non-matching shapes fall back to SDPBackend::math in C++ and decompose
+    # into primitive ops (matmul/softmax) before reaching this lowering.
     scale = calculate_scale(query, scale)
     N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value)
-
-    # Decode-only GQA fast path: q is (B,Hq,1,Dh), B==1, Hq!=H, Hq%H==0.
-    # Always use the 2-kernel decode path:
-    #   1) block partials over (kv head, sequence block)
-    #   2) reduce/merge across blocks
-    # This keeps KV shared across qsub, avoids dh0-outer duplication, and
-    # stores compact partials instead of full score/prob tensors in DRAM.
-    if L == 1 and Hq != H and N == 1 and (Hq % H) == 0:
-        g = Hq // H
-        vector_lane = extension_config.vpu_num_lanes
-        tile_e = vector_lane
-        dh_tiles = E // tile_e
-        decode_gqa_block_size = 512
-        BlkS = decode_gqa_block_size if S >= decode_gqa_block_size else int(S)
-        # Padding-based tail handling: allow S not divisible by BlkS.
-        nblk = (S + BlkS - 1) // BlkS
-        HgDhTiles = H * g * dh_tiles
-        tile_pack = tile_e * 2
-
-        partial_layout = ir.FixedLayout(
-            query.get_device(),
-            torch.float32,
-            [HgDhTiles, nblk, tile_pack],
-        )
-        partial_tmpl = MLIRDecodeGQASDPAPartialTemplate([query, key, value], partial_layout, scale, BlkS=BlkS)
-        partial = partial_tmpl.generate().output_node()
-        partial.realize()
-        reduce_tmpl = MLIRDecodeGQASDPAReduceTemplate([partial], layout, BlkS=BlkS)
-        out_node = reduce_tmpl.generate().output_node()
-        return (out_node, None, None, None, None, None, None, None, None)
-
     mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale)
-
-    # _scaled_dot_product_flash_attention has to return a tuple which has 9 values
-    # since its backward(_scaled_dot_product_flash_attention_backward) needs that values.
-    # (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
     return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None)
 
+
+
 def conv_layout(
     x: TensorBox,
     weight: TensorBox,
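
The new comment in tuned_flash_sdpa notes that non-matching shapes never reach this lowering; they take SDPBackend::math and decompose into matmul/softmax. As a reference for what that fallback computes, a plain-PyTorch sketch (illustrative only, not the actual Inductor decomposition):

import math
import torch

def sdpa_math_reference(q, k, v, scale=None):
    # softmax(QK^T * scale) V -- the semantics the math fallback decomposes to.
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)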
@@ -345,5 +314,4 @@ def _sort_layouts(x: TensorBox, dim: int, descending: bool):
 
 if extension_config.CONFIG_USE_TIMING_POOLING:
     lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()})  # FIXME: maxpool should be implemented as a template
-
 lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()})
