1- from typing import List , Optional , cast
1+ from typing import List , Optional
2+ import math
3+ import itertools
24
35import sympy
4- from torch ._inductor .ir import Buffer , IRNode
5- from torch ._inductor .virtualized import V
6+ from torch ._inductor .ir import IRNode
67
78from PyTorchSimFrontend .mlir import mlir_common
89from PyTorchSimFrontend .mlir .mlir_template import MLIRTemplate , MLIRTemplateKernel
1112TEMPLATE = r"""
1213{{kernel.def_global_vars()}}
1314
14- func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} {
15- {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }}
16- {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }}
17- {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }}
15+ func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} {
16+ {% for i in range(NUM_INPUTS) %}
17+ {{ kernel.def_sram_buffer("X" + i|string, INPUT_TILE_DESCS[i], id=i, indent_size=2) }}
18+ {% endfor %}
19+ {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=NUM_INPUTS, indent_size=2) }}
1820 {{ kernel.def_local_vars(indent_size=2) }}
1921
2022 affine.for %cat_block = 0 to 1 step 1 {
21- {% if DIM == 0 %}
22- affine.for %index0 = 0 to {{ X0_ROWS }} step 1 {
23- affine.for %index1 = 0 to {{ COLS }} step 1 {
24- {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
25- {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
26- }
27- }
28-
29- affine.for %index2 = 0 to {{ X1_ROWS }} step 1 {
30- affine.for %index3 = 0 to {{ COLS }} step 1 {
31- {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
32- {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
33- }
34- }
35- {% else %}
36- affine.for %index0 = 0 to {{ ROWS }} step 1 {
37- affine.for %index1 = 0 to {{ X0_COLS }} step 1 {
38- {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
39- {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
40- }
41- affine.for %index3 = 0 to {{ X1_COLS }} step 1 {
42- {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
43- {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
44- }
45- }
46- {% endif %}
23+ {%- for d in range(RANK-1) %}
24+ affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} {
25+ {%- endfor %}
26+ {%- for i in range(NUM_INPUTS) %}
27+ // Input tensor{{ i }}
28+ affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} {
29+ %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }})
30+ {{ kernel.def_dma_op("MVIN", "X" + i|string, INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
31+ {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
32+ } { inner_loop=true }
33+ {%- endfor %}
34+
35+ {%- for d in range(RANK-1) %}
36+ } { outer_loop=true }
37+ {%- endfor %}
4738 } { outer_loop=true }
4839 return
4940}
@@ -66,79 +57,132 @@ def render(
6657 is_out_variant = template_buffer_node is not None
6758 if is_out_variant :
6859 self .output_node = template_buffer_node
69- # cat template currently emits a single output buffer and does not
70- # support epilogue output remapping.
71-
72- def _unwrap_node (n ):
73- return n .node if hasattr (n , "node" ) else n
74-
75- x0 = _unwrap_node (self .input_nodes [0 ])
76- x1 = _unwrap_node (self .input_nodes [1 ])
77- y = _unwrap_node (self .output_node )
78-
79- def _as_int (v ):
80- try :
81- return int (v )
82- except Exception :
83- return int (V .graph .sizevars .size_hint (v ))
84-
85- x0_rows = _as_int (x0 .get_size ()[0 ])
86- x1_rows = _as_int (x1 .get_size ()[0 ])
87- x0_cols = _as_int (x0 .get_size ()[1 ])
88- x1_cols = _as_int (x1 .get_size ()[1 ])
89- y_cols = _as_int (y .get_size ()[1 ])
90- kernel .loop_size = None
91-
92- # 2D cat template with contiguous layout.
93- x0_tile_desc = mlir_common .MLIRMultiDimTile ([1 , 1 ], kernel .vector_lane , vlane_split_axis = 1 , vlane_stride = 1 )
94- x0_tile_desc .set_tile_size_stride ([1 , 1 ], [1 , 1 ])
95- x0_tile_desc .set_name ("x0_cat_tile" )
96- x1_tile_desc = mlir_common .MLIRMultiDimTile ([1 , 1 ], kernel .vector_lane , vlane_split_axis = 1 , vlane_stride = 1 )
97- x1_tile_desc .set_tile_size_stride ([1 , 1 ], [1 , 1 ])
98- x1_tile_desc .set_name ("x1_cat_tile" )
99- y_tile_desc = mlir_common .MLIRMultiDimTile ([1 , 1 ], kernel .vector_lane , vlane_split_axis = 1 , vlane_stride = 1 )
100- y_tile_desc .set_tile_size_stride ([1 , 1 ], [1 , 1 ])
60+
61+ input_nodes = self .input_nodes
62+ y = self .output_node
63+ num_inputs = len (self .input_nodes )
64+ rank = len (y .get_size ())
65+
66+ input_sizes = [x .get_size () for x in input_nodes ]
67+ output_sizes = [sz for dim , sz in enumerate (y .get_size ()) if dim != self .dim ]
68+ output_dim = [dim for dim , sz in enumerate (y .get_size ()) if dim != self .dim ]
69+
70+ tile_sizes = tile_info if tile_info is not None else [1 ] * len (output_sizes )
71+
72+ # Calculate non-concat dimensions tile size (for SPAD calculation)
73+ non_dim_tile_elements = math .prod (tile_sizes ) if tile_sizes else 1
74+ non_dim_tile_spad = non_dim_tile_elements * kernel .precision
75+
76+ # Calculate max tile size for concat dimension for each input
77+ # SPAD needs to hold: input tile + output tile (for same non-dim tile)
78+ max_spad_per_input = kernel .spad_info ["spad_size" ] * kernel .vector_lane // 2
79+ extra_concat_input = math .ceil (max_spad_per_input / non_dim_tile_spad ) - num_inputs
80+
81+ input_tile_sizes_dim = []
82+ input_tile_descs = []
83+ input_idxs = []
84+ output_idxs = []
85+ output_strides = y .get_layout ().stride
86+
87+ cumulative_offsets = [0 ]
88+ for i in range (num_inputs - 1 ):
89+ cumulative_offsets .append (cumulative_offsets [- 1 ] + input_sizes [i ][self .dim ])
90+
91+ for i , x in enumerate (input_nodes ):
92+ # Calculate max tile size for concat dimension for this input
93+ input_dim_size = input_sizes [i ][self .dim ]
94+ if extra_concat_input > 0 and non_dim_tile_elements > 0 :
95+ max_tile_dim = min (
96+ input_dim_size , extra_concat_input
97+ )
98+ extra_concat_input -= max_tile_dim
99+ else :
100+ max_tile_dim = 1
101+
102+ input_tile_sizes_dim .append (max_tile_dim )
103+
104+ # Build full tile size list for this input
105+ full_tile_sizes = []
106+ tile_size_idx = 0
107+ for d in range (rank ):
108+ if d != self .dim :
109+ full_tile_sizes .append (tile_sizes [tile_size_idx ])
110+ tile_size_idx += 1
111+ else :
112+ full_tile_sizes .append (max_tile_dim )
113+
114+ tile_desc = mlir_common .MLIRMultiDimTile (
115+ full_tile_sizes ,
116+ kernel .vector_lane ,
117+ vlane_split_axis = rank - 1 ,
118+ vlane_stride = 1
119+ )
120+ tile_desc .set_tile_size (full_tile_sizes )
121+ tile_desc .set_name (f"x{ i } _cat_tile" )
122+ input_tile_descs .append (tile_desc )
123+ x_stride = x .get_layout ().stride
124+
125+ input_idx = []
126+ output_idx = []
127+ for d in range (rank ):
128+ if d != self .dim :
129+ input_idx_symbol = sympy .Symbol (f"index{ d } " )
130+ output_idx_symbol = sympy .Symbol (f"index{ d } " )
131+ else :
132+ input_idx_symbol = sympy .Symbol (f"index_local{ self .dim } _{ i } " )
133+ output_idx_symbol = sympy .Symbol (f"index{ self .dim } _{ i } " )
134+ input_idx .append (input_idx_symbol * x_stride [d ])
135+ output_idx .append (output_idx_symbol * output_strides [d ])
136+ input_idxs .append (input_idx )
137+ output_idxs .append (output_idx )
138+
139+ # Output tile size: use max of all input concat tile sizes for output
140+ max_output_tile_dim = max (input_tile_sizes_dim ) if input_tile_sizes_dim else 1
141+ output_full_tile_sizes = []
142+ tile_size_idx = 0
143+ for d in range (rank ):
144+ if d != self .dim :
145+ output_full_tile_sizes .append (tile_sizes [tile_size_idx ])
146+ tile_size_idx += 1
147+ else :
148+ output_full_tile_sizes .append (max_output_tile_dim )
149+
150+ y_tile_desc = mlir_common .MLIRMultiDimTile (
151+ output_full_tile_sizes ,
152+ kernel .vector_lane ,
153+ vlane_split_axis = rank - 1 ,
154+ vlane_stride = 1
155+ )
156+ y_tile_desc .set_tile_size (output_full_tile_sizes )
101157 y_tile_desc .set_name ("y_cat_tile" )
102158
103- if self .dim == 0 :
104- # Flattened offsets for dim=0 cat.
105- x0_idx = [sympy .Symbol ("index0" ) * x0_cols , sympy .Symbol ("index1" )]
106- x1_idx = [sympy .Symbol ("index2" ) * x1_cols , sympy .Symbol ("index3" )]
107- y0_idx = [sympy .Symbol ("index0" ) * y_cols , sympy .Symbol ("index1" )]
108- y1_idx = [(sympy .Symbol ("index2" ) + x0_rows ) * y_cols , sympy .Symbol ("index3" )]
109- else :
110- # Flattened offsets for dim=1 cat.
111- x0_idx = [sympy .Symbol ("index0" ) * x0_cols , sympy .Symbol ("index1" )]
112- x1_idx = [sympy .Symbol ("index0" ) * x1_cols , sympy .Symbol ("index3" )]
113- y0_idx = [sympy .Symbol ("index0" ) * y_cols , sympy .Symbol ("index1" )]
114- y1_idx = [sympy .Symbol ("index0" ) * y_cols , sympy .Symbol ("index3" ) + x0_cols ]
159+ input_names = [f"X{ i } " for i in range (num_inputs )]
160+ names_str = ", " .join (input_names + ["out_ptr1" if is_out_variant else "Y" ])
161+ indent_size = 2 + (rank - 1 ) * 2 + 4
115162
116163 kernel .render_options = dict (
117164 KERNEL_NAME = self .name ,
118165 kernel = kernel ,
119- X0 = x0 ,
120- X1 = x1 ,
121166 Y = y ,
122167 OUT_DVAR = "out_ptr1" if is_out_variant else "Y" ,
123- NAMES_STR = "X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y" ,
168+ NAMES_STR = names_str ,
169+ INPUT_NAMES = input_nodes ,
170+ NUM_INPUTS = num_inputs ,
171+ RANK = rank ,
124172 DIM = self .dim ,
125- X0_ROWS = x0_rows ,
126- X1_ROWS = x1_rows ,
127- ROWS = x0_rows ,
128- X0_COLS = x0_cols ,
129- X1_COLS = x1_cols ,
130- COLS = x0_cols ,
131- X0_TILE_DESC = x0_tile_desc ,
132- X1_TILE_DESC = x1_tile_desc ,
173+ INPUT_SIZES = input_sizes ,
174+ OUTPUT_SIZES = output_sizes ,
175+ OUTPUT_DIM = output_dim ,
176+ TILE_SIZES = tile_sizes ,
177+ INPUT_TILE_SIZES_DIM = input_tile_sizes_dim ,
178+ INPUT_TILE_DESCS = input_tile_descs ,
133179 Y_TILE_DESC = y_tile_desc ,
134- X0_IDX = x0_idx ,
135- X1_IDX = x1_idx ,
136- Y0_IDX = y0_idx ,
137- Y1_IDX = y1_idx ,
180+ INPUT_IDXS = input_idxs ,
181+ OUTPUT_IDXS = output_idxs ,
182+ CUMULATIVE_OFFSETS = cumulative_offsets ,
183+ INDENT_SIZE = indent_size ,
138184 input_reorder = self .input_reorder ,
139185 )
140- # Needed when epilogue fusion requests set_ranges().
141- kernel .dim_aliasing = {"index0" : "index0" , "index1" : "index1" }
142186
143187 if hasattr (self .output_node , "node" ) and hasattr (self .output_node .node , "get_name" ):
144188 output_node_name = self .output_node .node .get_name ()
@@ -165,3 +209,58 @@ def _as_int(v):
165209
166210 code = self ._template_from_string (TEMPLATE ).render (** kernel .render_options )
167211 return code
212+
def get_tile_candidates(
    self,
    kernel: "MLIRTemplateKernel",
    template_buffer_node=None,
    epilogue_nodes: "Optional[List[IRNode]]" = None,
    **kwargs,
):
    """Propose tile shapes for the non-concatenated output dimensions.

    Only dimensions other than ``self.dim`` are tiled here; the concat
    dimension is not part of the returned lists.

    Args:
        kernel: Supplies ``vector_lane``, ``precision`` and
            ``spad_info["spad_size"]`` used for the scratchpad budget.
        template_buffer_node: When given, replaces ``self.output_node``
            before the output sizes are read.
        epilogue_nodes: Accepted for interface compatibility; unused here.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        At most four tile-size lists (one entry per non-concat dimension),
        ordered by decreasing element count. ``[[1]]`` when the output has
        no non-concat dimension.
    """
    if template_buffer_node is not None:
        self.output_node = template_buffer_node

    concat_axis = self.dim
    n_inputs = len(self.input_nodes)
    free_extents = [
        extent
        for axis, extent in enumerate(self.output_node.get_size())
        if axis != concat_axis
    ]
    if not free_extents:
        return [[1]]

    lanes = kernel.vector_lane
    elem_bytes = kernel.precision
    spad_size = kernel.spad_info["spad_size"]

    # Upper bound on a single axis tile, shared by every axis.
    per_axis_cap = spad_size // (lanes * elem_bytes * 2 * n_inputs)

    per_axis_options = []
    for extent in free_extents:
        cap = min(extent, per_axis_cap)

        # Multiples of the vector lane count up to the cap.
        options = {
            lanes * k
            for k in range(1, cap // lanes + 1)
            if lanes * k <= extent
        }
        # Powers of two up to the cap.
        if cap > 0:
            options.update(
                2 ** e
                for e in range(int(math.log2(cap)) + 1)
                if 2 ** e <= extent
            )
        # The full extent is always offered; added last so an equal tile
        # value already present is the one that is kept.
        options.add(extent)

        # Keep only the five smallest options per axis.
        per_axis_options.append(sorted(options)[:5])

    # Retain combinations whose inputs-plus-output footprint fits the budget.
    budget = spad_size * lanes
    fitting = [
        list(combo)
        for combo in itertools.product(*per_axis_options)
        if math.prod(combo) * (n_inputs + 1) * elem_bytes <= budget
    ]
    if not fitting:
        fitting = [[1] * len(free_extents)]

    # Largest tiles first; a reversed sort is stable, so ties keep their
    # generation order.
    fitting.sort(key=math.prod, reverse=True)
    return fitting[:4]
0 commit comments