20 | 20 | from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate |
21 | 21 | from PyTorchSimFrontend.mlir.mlir_sdpa_template import ( |
22 | 22 | MLIRFlashSDPATemplate, |
23 | | - MLIRDecodeGQASDPAPartialTemplate, |
24 | | - MLIRDecodeGQASDPAReduceTemplate, |
25 | 23 | flash_sdpa_args, |
26 | 24 | calculate_scale, |
27 | 25 | ) |
@@ -51,56 +49,27 @@ def tuned_bmm(mat1, mat2, *, layout=None): |
51 | 49 |
52 | 50 |
53 | 51 | def tuned_flash_sdpa( |
54 | | - query : TensorBox, |
55 | | - key : TensorBox, |
56 | | - value : TensorBox, |
| 52 | + query : TensorBox, |
| 53 | + key : TensorBox, |
| 54 | + value : TensorBox, |
57 | 55 | attn_bias : Optional[TensorBox] = None, |
58 | | - dropout_p : float = 0.0, |
59 | | - is_causal : bool = False, |
| 56 | + dropout_p : float = 0.0, |
| 57 | + is_causal : bool = False, |
60 | 58 | return_debug_mask : bool = False, |
61 | | - scale : Optional[float] = None) -> tuple: |
62 | | - |
63 | | - |
| 59 | + scale : Optional[float] = None, |
| 60 | + enable_gqa : bool = False) -> tuple: |
| 61 | + # _fused_sdp_choice in C++ already guarantees: |
| 62 | + # L == S (prefill), Hq == H (non-GQA), dropout_p == 0.0 |
| 63 | + # before routing here via SDPBackend::overrideable. |
| 64 | + # Non-matching shapes fall back to SDPBackend::math in C++ and decompose |
| 65 | + # into primitive ops (matmul/softmax) before reaching this lowering. |
64 | 66 | scale = calculate_scale(query, scale) |
65 | 67 | N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) |
66 | | - |
67 | | - # Decode-only GQA fast path: q is (B,Hq,1,Dh), B==1, Hq!=H, Hq%H==0. |
68 | | - # Always use the 2-kernel decode path: |
69 | | - # 1) block partials over (kv head, sequence block) |
70 | | - # 2) reduce/merge across blocks |
71 | | - # This keeps KV shared across qsub, avoids dh0-outer duplication, and |
72 | | - # stores compact partials instead of full score/prob tensors in DRAM. |
73 | | - if L == 1 and Hq != H and N == 1 and (Hq % H) == 0: |
74 | | - g = Hq // H |
75 | | - vector_lane = extension_config.vpu_num_lanes |
76 | | - tile_e = vector_lane |
77 | | - dh_tiles = E // tile_e |
78 | | - decode_gqa_block_size = 512 |
79 | | - BlkS = decode_gqa_block_size if S >= decode_gqa_block_size else int(S) |
80 | | - # Padding-based tail handling: allow S not divisible by BlkS. |
81 | | - nblk = (S + BlkS - 1) // BlkS |
82 | | - HgDhTiles = H * g * dh_tiles |
83 | | - tile_pack = tile_e * 2 |
84 | | - |
85 | | - partial_layout = ir.FixedLayout( |
86 | | - query.get_device(), |
87 | | - torch.float32, |
88 | | - [HgDhTiles, nblk, tile_pack], |
89 | | - ) |
90 | | - partial_tmpl = MLIRDecodeGQASDPAPartialTemplate([query, key, value], partial_layout, scale, BlkS=BlkS) |
91 | | - partial = partial_tmpl.generate().output_node() |
92 | | - partial.realize() |
93 | | - reduce_tmpl = MLIRDecodeGQASDPAReduceTemplate([partial], layout, BlkS=BlkS) |
94 | | - out_node = reduce_tmpl.generate().output_node() |
95 | | - return (out_node, None, None, None, None, None, None, None, None) |
96 | | - |
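Reviewer note: for readers without the removed templates at hand, the math the deleted fast path describes (per-block partials over the KV sequence, then a merge across blocks) is the standard two-pass flash-decode recurrence. A minimal PyTorch sketch of that recurrence is below; the function name, tensor shapes, and the `blk=512` blocking mirror the comments above but are illustrative only, not the actual MLIR codegen of `MLIRDecodeGQASDPAPartialTemplate` / `MLIRDecodeGQASDPAReduceTemplate`.

```python
import torch

def decode_sdpa_two_pass(q, k, v, scale, blk=512):
    # q: (Hq, E) one decode query per head; k, v: (H, S, E) shared KV heads (float32 assumed).
    # Pass 1: per (kv head, sequence block) partials -- running max, exp-sum, exp-weighted V sum.
    # Pass 2: merge the block partials with a log-sum-exp rescale.
    Hq, E = q.shape
    H, S, _ = k.shape
    g = Hq // H                           # GQA group size: g query heads share one kv head
    out = torch.empty(Hq, v.shape[-1])
    for hq in range(Hq):
        kv = hq // g                      # kv head shared by this query-head group
        m = torch.tensor(float("-inf"))   # running max of scores
        l = torch.tensor(0.0)             # running sum of exp(scores - m)
        acc = torch.zeros(v.shape[-1])    # running exp-weighted sum of V rows
        for s0 in range(0, S, blk):       # tail blocks handled by slicing, S need not divide by blk
            ks = k[kv, s0:s0 + blk]       # (blk, E)
            vs = v[kv, s0:s0 + blk]       # (blk, Ev)
            scores = (ks @ q[hq]) * scale # (blk,)
            m_blk = scores.max()
            p = torch.exp(scores - m_blk)
            # merge this block's partial (m_blk, sum, weighted V sum) into the running state
            m_new = torch.maximum(m, m_blk)
            alpha = torch.exp(m - m_new)      # rescale the old accumulator
            beta = torch.exp(m_blk - m_new)   # rescale the new block
            l = alpha * l + beta * p.sum()
            acc = alpha * acc + beta * (p @ vs)
            m = m_new
        out[hq] = acc / l
    return out
```

Because only the compact (max, sum, weighted-sum) partials cross the block boundary, nothing the size of the full score/probability matrix ever needs to be materialized, which is the property the removed comment block is describing.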
97 | 68 | mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) |
98 | | - |
99 | | - # _scaled_dot_product_flash_attention has to return a tuple which has 9 values |
100 | | - # since its backward(_scaled_dot_product_flash_attention_backward) needs that values. |
101 | | - # (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) |
102 | 69 | return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) |
103 | 70 |
| 71 | + |
| 72 | + |
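Reviewer note: the new comments state the preconditions as facts about the C++ `_fused_sdp_choice`; a Python-level restatement of that guard may help when reading this lowering in isolation. The helper name below is hypothetical, the actual decision is made in C++ before dispatch ever reaches `tuned_flash_sdpa`.

```python
def can_use_overrideable_sdpa(query, key, value, dropout_p=0.0):
    # Mirrors the guarantees described in tuned_flash_sdpa's comment: only the
    # prefill, non-GQA, dropout-free case is routed to the overrideable fused
    # kernel; every other shape is expected to take SDPBackend::math and
    # decompose into matmul/softmax before reaching this lowering.
    n, hq, l, _ = query.shape   # (B, Hq, L, E)
    _, h, s, _ = key.shape      # (B, H,  S, E)
    return (
        l == s                  # prefill: query length equals kv length
        and hq == h             # non-GQA: one kv head per query head
        and dropout_p == 0.0    # no dropout inside the fused kernel
    )
```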
104 | 73 | def conv_layout( |
105 | 74 | x: TensorBox, |
106 | 75 | weight: TensorBox, |
@@ -345,5 +314,4 @@ def _sort_layouts(x: TensorBox, dim: int, descending: bool): |
345 | 314 |
346 | 315 | if extension_config.CONFIG_USE_TIMING_POOLING: |
347 | 316 | lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template |
348 | | - |
349 | 317 | lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) |
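Reviewer note: for context on how this registration is exercised, a usage sketch follows. Whether the call actually routes to `aten._scaled_dot_product_fused_attention_overrideable` (and hence to `tuned_flash_sdpa`) depends on the simulator backend being active and on `_fused_sdp_choice` selecting the overrideable backend; on stock CUDA/CPU the same call would take one of the built-in SDPA backends instead.

```python
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # With the overrideable backend selected, this lowers to
        # aten._scaled_dot_product_fused_attention_overrideable during
        # Inductor lowering and lands in tuned_flash_sdpa.
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 128, 64)   # (B, H, L, E): prefill, non-GQA, no dropout
out = torch.compile(Attn())(q, k, v)
```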