Skip to content

Commit caece61

Browse files
committed
[Template] Polish template kernel of cat operation
1 parent fc247be commit caece61

File tree

9 files changed

+355
-205
lines changed

9 files changed

+355
-205
lines changed

PyTorchSimDevice/torch_openreg/openreg/__init__.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -256,52 +256,6 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs):
256256
from .random import * # noqa: F403
257257
from .amp import *
258258

259-
def _precheck_cat_out_args(args, kwargs):
    """Validate the arguments of an ``aten::cat.out`` call.

    The arguments are read the same way the op receives them: ``tensors``
    is ``args[0]`` (or ``kwargs["tensors"]``), ``dim`` is ``args[1]`` (or
    ``kwargs["dim"]``, default 0) and ``out`` is ``kwargs["out"]`` (or
    ``args[2]``).

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        None. When no ``out`` tensor was supplied nothing is checked.

    Raises:
        RuntimeError: If the tensor list is empty or not a list/tuple of
            tensors, ``out`` is not a tensor, the inputs are 0-dim, ``dim``
            is out of range, the inputs differ in rank or dtype, or ``out``
            does not have the shape of the concatenated result.
    """
    tensors = args[0] if len(args) > 0 else kwargs.get("tensors")
    dim = args[1] if len(args) > 1 else kwargs.get("dim", 0)
    out = kwargs.get("out", args[2] if len(args) > 2 else None)

    if out is None:
        return
    if not isinstance(tensors, (list, tuple)) or len(tensors) == 0:
        raise RuntimeError("aten::cat.out requires non-empty tensor list")
    if not all(isinstance(t, torch.Tensor) for t in tensors):
        raise RuntimeError("aten::cat.out tensors must be Tensor values")
    if not isinstance(out, torch.Tensor):
        raise RuntimeError("aten::cat.out out must be a Tensor")

    first = tensors[0]
    rank = first.dim()
    if rank == 0:
        raise RuntimeError("aten::cat.out does not support scalar inputs")

    # Wrap a negative dim into [0, rank) but keep the caller's value so the
    # out-of-range message reports what was actually passed in.
    cat_dim = dim + rank if dim < 0 else dim
    if cat_dim < 0 or cat_dim >= rank:
        raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}")
    if any(t.dim() != rank for t in tensors):
        raise RuntimeError("aten::cat.out inputs must have the same rank")
    if any(t.dtype != first.dtype for t in tensors):
        raise RuntimeError("aten::cat.out inputs must have the same dtype")
    if out.dim() != rank:
        raise RuntimeError("aten::cat.out out rank mismatch")

    # Every dimension except the concatenated one must agree across all
    # inputs and with the output.
    for d in range(rank):
        if d == cat_dim:
            continue
        base = first.shape[d]
        if any(t.shape[d] != base for t in tensors[1:]):
            raise RuntimeError(
                f"aten::cat.out non-concatenated dimension mismatch at dim={d}"
            )
        if out.shape[d] != base:
            raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}")

    # Along the concatenated dimension the output must hold the sum of the
    # input extents.
    expected = sum(t.shape[cat_dim] for t in tensors)
    if out.shape[cat_dim] != expected:
        raise RuntimeError(
            f"aten::cat.out out concatenated dimension mismatch at dim={cat_dim}: "
            f"expected {expected}, got {out.shape[cat_dim]}"
        )
304-
305259
def eager_to_compile(op_name):
306260
"""
307261
Register an eager mode operation as a graph-based implementation using torch.compile().
@@ -313,9 +267,6 @@ def eager_to_compile(op_name):
313267
torch.npu.eager_to_compile("aten::mul.Tensor")
314268
"""
315269
def wrapper(*args, **kwargs):
316-
if op_name == "aten::cat.out":
317-
_precheck_cat_out_args(args, kwargs)
318-
319270
@torch.compile(dynamic=False)
320271
def dummy_graph(*args, **kwargs):
321272
# Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor

PyTorchSimFrontend/mlir/mlir_bmm_template.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@
154154
class MLIRBMMTemplate(MLIRTemplate):
155155
def __init__(self, input_nodes, layout, input_reorder=None):
156156
super().__init__("kernel", input_nodes, layout, input_reorder)
157+
self.support_epilogue_fusion = True
158+
self.support_prologue_fusion = True
159+
self.support_reduction_fusion = True
157160

158161
def render(self,
159162
kernel: MLIRTemplateKernel,
Lines changed: 193 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import List, Optional, cast
1+
from typing import List, Optional
2+
import math
3+
import itertools
24

35
import sympy
4-
from torch._inductor.ir import Buffer, IRNode
5-
from torch._inductor.virtualized import V
6+
from torch._inductor.ir import IRNode
67

78
from PyTorchSimFrontend.mlir import mlir_common
89
from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel
@@ -11,39 +12,29 @@
1112
TEMPLATE = r"""
1213
{{kernel.def_global_vars()}}
1314
14-
func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} {
15-
{{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }}
16-
{{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }}
17-
{{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }}
15+
func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} {
16+
{% for i in range(NUM_INPUTS) %}
17+
{{ kernel.def_sram_buffer("X" + i|string, INPUT_TILE_DESCS[i], id=i, indent_size=2) }}
18+
{% endfor %}
19+
{{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=NUM_INPUTS, indent_size=2) }}
1820
{{ kernel.def_local_vars(indent_size=2) }}
1921
2022
affine.for %cat_block = 0 to 1 step 1 {
21-
{% if DIM == 0 %}
22-
affine.for %index0 = 0 to {{ X0_ROWS }} step 1 {
23-
affine.for %index1 = 0 to {{ COLS }} step 1 {
24-
{{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
25-
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
26-
}
27-
}
28-
29-
affine.for %index2 = 0 to {{ X1_ROWS }} step 1 {
30-
affine.for %index3 = 0 to {{ COLS }} step 1 {
31-
{{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
32-
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
33-
}
34-
}
35-
{% else %}
36-
affine.for %index0 = 0 to {{ ROWS }} step 1 {
37-
affine.for %index1 = 0 to {{ X0_COLS }} step 1 {
38-
{{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
39-
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
40-
}
41-
affine.for %index3 = 0 to {{ X1_COLS }} step 1 {
42-
{{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
43-
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
44-
}
45-
}
46-
{% endif %}
23+
{%- for d in range(RANK-1) %}
24+
affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} {
25+
{%- endfor %}
26+
{%- for i in range(NUM_INPUTS) %}
27+
// Input tensor{{ i }}
28+
affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} {
29+
%index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }})
30+
{{ kernel.def_dma_op("MVIN", "X" + i|string, INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
31+
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
32+
} { inner_loop=true }
33+
{%- endfor %}
34+
35+
{%- for d in range(RANK-1) %}
36+
} { outer_loop=true }
37+
{%- endfor %}
4738
} { outer_loop=true }
4839
return
4940
}
@@ -66,79 +57,132 @@ def render(
6657
is_out_variant = template_buffer_node is not None
6758
if is_out_variant:
6859
self.output_node = template_buffer_node
69-
# cat template currently emits a single output buffer and does not
70-
# support epilogue output remapping.
71-
72-
def _unwrap_node(n):
73-
return n.node if hasattr(n, "node") else n
74-
75-
x0 = _unwrap_node(self.input_nodes[0])
76-
x1 = _unwrap_node(self.input_nodes[1])
77-
y = _unwrap_node(self.output_node)
78-
79-
def _as_int(v):
80-
try:
81-
return int(v)
82-
except Exception:
83-
return int(V.graph.sizevars.size_hint(v))
84-
85-
x0_rows = _as_int(x0.get_size()[0])
86-
x1_rows = _as_int(x1.get_size()[0])
87-
x0_cols = _as_int(x0.get_size()[1])
88-
x1_cols = _as_int(x1.get_size()[1])
89-
y_cols = _as_int(y.get_size()[1])
90-
kernel.loop_size = None
91-
92-
# 2D cat template with contiguous layout.
93-
x0_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1)
94-
x0_tile_desc.set_tile_size_stride([1, 1], [1, 1])
95-
x0_tile_desc.set_name("x0_cat_tile")
96-
x1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1)
97-
x1_tile_desc.set_tile_size_stride([1, 1], [1, 1])
98-
x1_tile_desc.set_name("x1_cat_tile")
99-
y_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1)
100-
y_tile_desc.set_tile_size_stride([1, 1], [1, 1])
60+
61+
input_nodes = self.input_nodes
62+
y = self.output_node
63+
num_inputs = len(self.input_nodes)
64+
rank = len(y.get_size())
65+
66+
input_sizes = [x.get_size() for x in input_nodes]
67+
output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim]
68+
output_dim = [dim for dim, sz in enumerate(y.get_size()) if dim != self.dim]
69+
70+
tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes)
71+
72+
# Calculate non-concat dimensions tile size (for SPAD calculation)
73+
non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1
74+
non_dim_tile_spad = non_dim_tile_elements * kernel.precision
75+
76+
# Calculate max tile size for concat dimension for each input
77+
# SPAD needs to hold: input tile + output tile (for same non-dim tile)
78+
max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2
79+
extra_concat_input = math.ceil(max_spad_per_input /non_dim_tile_spad) - num_inputs
80+
81+
input_tile_sizes_dim = []
82+
input_tile_descs = []
83+
input_idxs = []
84+
output_idxs = []
85+
output_strides = y.get_layout().stride
86+
87+
cumulative_offsets = [0]
88+
for i in range(num_inputs - 1):
89+
cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim])
90+
91+
for i, x in enumerate(input_nodes):
92+
# Calculate max tile size for concat dimension for this input
93+
input_dim_size = input_sizes[i][self.dim]
94+
if extra_concat_input > 0 and non_dim_tile_elements > 0:
95+
max_tile_dim = min(
96+
input_dim_size, extra_concat_input
97+
)
98+
extra_concat_input -= max_tile_dim
99+
else:
100+
max_tile_dim = 1
101+
102+
input_tile_sizes_dim.append(max_tile_dim)
103+
104+
# Build full tile size list for this input
105+
full_tile_sizes = []
106+
tile_size_idx = 0
107+
for d in range(rank):
108+
if d != self.dim:
109+
full_tile_sizes.append(tile_sizes[tile_size_idx])
110+
tile_size_idx += 1
111+
else:
112+
full_tile_sizes.append(max_tile_dim)
113+
114+
tile_desc = mlir_common.MLIRMultiDimTile(
115+
full_tile_sizes,
116+
kernel.vector_lane,
117+
vlane_split_axis=rank - 1,
118+
vlane_stride=1
119+
)
120+
tile_desc.set_tile_size(full_tile_sizes)
121+
tile_desc.set_name(f"x{i}_cat_tile")
122+
input_tile_descs.append(tile_desc)
123+
x_stride = x.get_layout().stride
124+
125+
input_idx = []
126+
output_idx = []
127+
for d in range(rank):
128+
if d != self.dim:
129+
input_idx_symbol = sympy.Symbol(f"index{d}")
130+
output_idx_symbol = sympy.Symbol(f"index{d}")
131+
else:
132+
input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}")
133+
output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}")
134+
input_idx.append(input_idx_symbol * x_stride[d])
135+
output_idx.append(output_idx_symbol * output_strides[d])
136+
input_idxs.append(input_idx)
137+
output_idxs.append(output_idx)
138+
139+
# Output tile size: use max of all input concat tile sizes for output
140+
max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1
141+
output_full_tile_sizes = []
142+
tile_size_idx = 0
143+
for d in range(rank):
144+
if d != self.dim:
145+
output_full_tile_sizes.append(tile_sizes[tile_size_idx])
146+
tile_size_idx += 1
147+
else:
148+
output_full_tile_sizes.append(max_output_tile_dim)
149+
150+
y_tile_desc = mlir_common.MLIRMultiDimTile(
151+
output_full_tile_sizes,
152+
kernel.vector_lane,
153+
vlane_split_axis=rank - 1,
154+
vlane_stride=1
155+
)
156+
y_tile_desc.set_tile_size(output_full_tile_sizes)
101157
y_tile_desc.set_name("y_cat_tile")
102158

103-
if self.dim == 0:
104-
# Flattened offsets for dim=0 cat.
105-
x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")]
106-
x1_idx = [sympy.Symbol("index2") * x1_cols, sympy.Symbol("index3")]
107-
y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")]
108-
y1_idx = [(sympy.Symbol("index2") + x0_rows) * y_cols, sympy.Symbol("index3")]
109-
else:
110-
# Flattened offsets for dim=1 cat.
111-
x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")]
112-
x1_idx = [sympy.Symbol("index0") * x1_cols, sympy.Symbol("index3")]
113-
y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")]
114-
y1_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index3") + x0_cols]
159+
input_names = [f"X{i}" for i in range(num_inputs)]
160+
names_str = ", ".join(input_names + ["out_ptr1" if is_out_variant else "Y"])
161+
indent_size = 2 + (rank - 1) * 2 + 4
115162

116163
kernel.render_options = dict(
117164
KERNEL_NAME=self.name,
118165
kernel=kernel,
119-
X0=x0,
120-
X1=x1,
121166
Y=y,
122167
OUT_DVAR="out_ptr1" if is_out_variant else "Y",
123-
NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y",
168+
NAMES_STR=names_str,
169+
INPUT_NAMES=input_nodes,
170+
NUM_INPUTS=num_inputs,
171+
RANK=rank,
124172
DIM=self.dim,
125-
X0_ROWS=x0_rows,
126-
X1_ROWS=x1_rows,
127-
ROWS=x0_rows,
128-
X0_COLS=x0_cols,
129-
X1_COLS=x1_cols,
130-
COLS=x0_cols,
131-
X0_TILE_DESC=x0_tile_desc,
132-
X1_TILE_DESC=x1_tile_desc,
173+
INPUT_SIZES=input_sizes,
174+
OUTPUT_SIZES=output_sizes,
175+
OUTPUT_DIM=output_dim,
176+
TILE_SIZES=tile_sizes,
177+
INPUT_TILE_SIZES_DIM=input_tile_sizes_dim,
178+
INPUT_TILE_DESCS=input_tile_descs,
133179
Y_TILE_DESC=y_tile_desc,
134-
X0_IDX=x0_idx,
135-
X1_IDX=x1_idx,
136-
Y0_IDX=y0_idx,
137-
Y1_IDX=y1_idx,
180+
INPUT_IDXS=input_idxs,
181+
OUTPUT_IDXS=output_idxs,
182+
CUMULATIVE_OFFSETS=cumulative_offsets,
183+
INDENT_SIZE=indent_size,
138184
input_reorder=self.input_reorder,
139185
)
140-
# Needed when epilogue fusion requests set_ranges().
141-
kernel.dim_aliasing = {"index0": "index0", "index1": "index1"}
142186

143187
if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"):
144188
output_node_name = self.output_node.node.get_name()
@@ -165,3 +209,58 @@ def _as_int(v):
165209

166210
code = self._template_from_string(TEMPLATE).render(**kernel.render_options)
167211
return code
212+
213+
def get_tile_candidates(
    self,
    kernel: MLIRTemplateKernel,
    template_buffer_node=None,
    epilogue_nodes: Optional[List[IRNode]] = None,
    **kwargs,
):
    """Enumerate tile-size candidates for the non-concatenated dimensions.

    A candidate holds one tile size per output dimension other than
    ``self.dim``, in dimension order. The concat dimension is not part of
    a candidate.

    Args:
        kernel: Kernel whose ``vector_lane``, ``precision`` and
            ``spad_info["spad_size"]`` bound the tile sizes.
        template_buffer_node: When given, replaces ``self.output_node``
            before the output shape is read.
        epilogue_nodes: Accepted for interface compatibility; not used.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        At most four candidates (lists of tile sizes), ordered by
        decreasing tile volume. ``[[1]]`` when the output has no
        non-concatenated dimension.
    """
    if template_buffer_node is not None:
        self.output_node = template_buffer_node

    num_inputs = len(self.input_nodes)
    non_concat_sizes = [
        size
        for axis, size in enumerate(self.output_node.get_size())
        if axis != self.dim
    ]
    if not non_concat_sizes:
        return [[1]]

    # Loop-invariant hardware limits.
    lane = kernel.vector_lane
    precision = kernel.precision
    spad_size = kernel.spad_info["spad_size"]
    per_dim_cap = spad_size // (lane * precision * 2 * num_inputs)
    spad_budget = spad_size * lane

    per_dim_options = []
    for dim_size in non_concat_sizes:
        max_tile = min(dim_size, per_dim_cap)

        # Multiples of the vector lane width that fit under the cap.
        options = {lane * mult for mult in range(1, max_tile // lane + 1)}

        # Powers of two up to the cap; integer doubling avoids the float
        # rounding of int(math.log2(...)).
        power = 1
        while power <= max_tile:
            options.add(power)
            power *= 2

        # The whole dimension as a single tile. Clamped to 1 because the
        # template in this file uses these values as `affine.for` steps,
        # and a zero-length dimension would otherwise produce a step of 0.
        options.add(max(dim_size, 1))

        # NOTE(review): this keeps the five SMALLEST options per dimension,
        # so larger tiles (often including the full dimension) are dropped
        # even though the final ranking below prefers the largest volume.
        # Confirm whether the five largest were intended.
        per_dim_options.append(sorted(options)[:5])

    # Keep only combinations whose input and output tiles fit the scratchpad.
    fitting = [
        list(combo)
        for combo in itertools.product(*per_dim_options)
        if math.prod(combo) * (num_inputs + 1) * precision <= spad_budget
    ]
    if not fitting:
        fitting = [[1] * len(non_concat_sizes)]

    # Stable sort: equal volumes keep their enumeration order.
    fitting.sort(key=math.prod, reverse=True)
    return fitting[:4]

0 commit comments

Comments
 (0)