Skip to content

Commit c5f085e

Browse files
HamHyungkyu and YWHyuk
authored and committed
[Frontend] Enhance vector size handling for low-precision paths in MLIR kernels
1 parent dd71c70 commit c5f085e

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,17 @@ def get_dtype_nbytes(dtype):
103103

104104
MLIR_INF = {
105105
"inf" : {
106+
"f16" : 0x7C00,
106107
"f32" : 0x7F800000,
107108
"f64" : 0x7FF0000000000000
108109
},
109110
"-inf" : {
111+
"f16" : 0xFC00,
110112
"f32" : 0xFF800000,
111113
"f64" : 0xFFF0000000000000
112114
},
113115
"nan" : {
116+
"f16" : 0x7C00,
114117
"f32" : 0x7FC00000,
115118
"f64" : 0x7FF8000000000000
116119
}
@@ -260,17 +263,23 @@ def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int])
260263
return tile_stride
261264

262265
def get_compute_vec_size(self, tile_size: list[int], reduction_numel: int, nr_rdim: int) -> int:
263-
if self.forced_vec_size is not None:
264-
return self.forced_vec_size
265-
266266
per_lane = self.get_numel_per_lane(tile_size)
267267
stride = self.vlane_stride
268268
if nr_rdim:
269269
val = per_lane // max(reduction_numel, 1)
270+
result = val
270271
for mult in [8, 4, 2]:
271272
if per_lane >= val * mult:
272-
return val * mult
273-
return val
273+
result = val * mult
274+
break
275+
if self.forced_vec_size is not None:
276+
# Cap while keeping result divisible by val (= reduction_size).
277+
# This preserves the assert(vec_len % reduction_size == 0) invariant.
278+
capped = (min(result, self.forced_vec_size) // max(val, 1)) * max(val, 1)
279+
result = max(capped, val)
280+
return result
281+
if self.forced_vec_size is not None:
282+
return self.forced_vec_size
274283
for mult in [8, 4, 2]:
275284
if (per_lane // stride) >= mult:
276285
return stride * mult
@@ -787,10 +796,24 @@ def codegen_nodes(self, nodes, kernel_name):
787796
# Set node range info
788797
vars, reduction_vars = self.set_ranges(group, reduction_group)
789798
tile_desc = self.compute_tile_size(nodes, vars, reduction_vars)
799+
_, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs()
800+
safe_vec_size = self.get_safe_vec_size(tile_desc.get_compute_vec_size())
801+
# For pointwise (non-reduction) kernels, cap the MLIR vector size so that
802+
# f16->f32 widening stays within LMUL<=4 (step and forced_vec_size must match).
803+
# Reduction kernels are left unchanged: their accumulator/multi_reduction
804+
# structure assumes compute_vec_size == step, so we must not split them here.
805+
tile_desc.vmap.forced_vec_size = safe_vec_size
806+
compute_vec = tile_desc.get_compute_vec_size()
807+
# RVV requires vector lengths that produce integer power-of-2 LMUL values.
808+
# Non-power-of-2 element counts (e.g. 24) cause LLVM WidenVectorResult crashes.
809+
# Raise BEFORE the try/except so this propagates to make_choices (not retried).
810+
if compute_vec > 1 and (compute_vec & (compute_vec - 1)) != 0:
811+
raise RecompileSignal(
812+
f"Non-power-of-2 compute_vec_size {compute_vec}: tile rejected (RVV requires power-of-2 LMUL)"
813+
)
790814
self.compute_body_loop.size = tile_desc.get_numel_per_lane()
791-
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
815+
self.compute_body_loop.step = compute_vec
792816
try:
793-
_, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs()
794817
with self as kernel:
795818
for node in nodes:
796819
node.run(vars, reduction_vars)
@@ -1035,6 +1058,42 @@ def __exit__(self, exc_type, exc_val, exc_tb):
10351058
self._nested_context_depth -= 1
10361059
if self._nested_context_depth == 0:
10371060
super().__exit__(exc_type, exc_val, exc_tb)
1061+
1062+
def get_safe_vec_size(self, default_vec_size: int = 64) -> int:
1063+
"""
1064+
Cap forced vector size for low-precision paths so widening ops
1065+
(e.g., f16/bf16 -> f32) do not exceed RVV LMUL limits.
1066+
1067+
Widening is legal up to source LMUL<=4 (destination LMUL<=8).
1068+
Using RVV relation LMUL = (SEW * VL) / VLEN, the safe source VL is:
1069+
VL <= 4 * VLEN / SEW
1070+
"""
1071+
1072+
if not hasattr(self, "buffer_types") or not self.buffer_types:
1073+
return default_vec_size
1074+
1075+
lowp_bits = []
1076+
for info in self.buffer_types.values():
1077+
dtype = info[0] if info else None
1078+
if dtype in DTYPE_LOWP_FP:
1079+
mlir_dtype = DTYPE_TO_MLIR[dtype]
1080+
lowp_bits.append(MLIR_TO_BIT[mlir_dtype])
1081+
1082+
if not lowp_bits:
1083+
return default_vec_size
1084+
1085+
min_lowp_bits = min(lowp_bits)
1086+
# Constraint: Vector element count must be compatible across all types.
1087+
# VLEN=256: f16 (LMUL=2) and f32 (LMUL=4) both yield 32 elements.
1088+
# Note: Gem5 version restricts widening ops to LMUL < 8 for destination registers.
1089+
# Max LMUL set to 2 to ensure compatibility/safety.
1090+
1091+
widen_safe_cap = self.vlen * 2 // min_lowp_bits
1092+
if widen_safe_cap <= 0:
1093+
return default_vec_size
1094+
1095+
vec_size = min(default_vec_size, widen_safe_cap)
1096+
return vec_size
10381097

10391098
@dataclasses.dataclass
10401099
class LoopLevel:

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ def set_tile_size(self, template_fusion_info, prologue=False):
12551255
numel_per_lane = tile_desc.get_numel_per_lane()
12561256
r_tile_size = tile_desc.get_tile_size()[-1]
12571257
nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size
1258-
tile_desc.vmap.forced_vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality...
1258+
tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(nr_outer_loop * 32) # Why? Emprically selected, other option failed to functionality...
12591259

12601260
self.reduction_fusion = True
12611261
self.r_tile_size = tile_desc.get_tile_size()[-1]
@@ -1266,7 +1266,7 @@ def set_tile_size(self, template_fusion_info, prologue=False):
12661266
self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop
12671267
self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop)
12681268
else:
1269-
tile_desc.vmap.forced_vec_size = 64
1269+
tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(64)
12701270

12711271
if prologue:
12721272
self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane()

0 commit comments

Comments (0)