@@ -103,14 +103,17 @@ def get_dtype_nbytes(dtype):
103103
# Bit patterns for IEEE 754 special values, keyed by MLIR float type name.
# Used to materialize inf/-inf/NaN constants when lowering to MLIR.
#
# f16 layout: 1 sign bit, 5 exponent bits, 10 mantissa bits.
#   +inf = 0x7C00 (exp all-ones, mantissa 0)
#   -inf = 0xFC00 (sign bit set)
#   qNaN = 0x7E00 (exp all-ones, top mantissa bit set — quiet NaN,
#                  consistent with the f32 0x7FC00000 / f64 0x7FF8... entries below).
# BUG FIX: "nan"/"f16" was 0x7C00, which is +inf in f16, not NaN.
MLIR_INF = {
    "inf": {
        "f16": 0x7C00,
        "f32": 0x7F800000,
        "f64": 0x7FF0000000000000,
    },
    "-inf": {
        "f16": 0xFC00,
        "f32": 0xFF800000,
        "f64": 0xFFF0000000000000,
    },
    "nan": {
        "f16": 0x7E00,
        "f32": 0x7FC00000,
        "f64": 0x7FF8000000000000,
    },
}
@@ -260,17 +263,23 @@ def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int])
260263 return tile_stride
261264
262265 def get_compute_vec_size (self , tile_size : list [int ], reduction_numel : int , nr_rdim : int ) -> int :
263- if self .forced_vec_size is not None :
264- return self .forced_vec_size
265-
266266 per_lane = self .get_numel_per_lane (tile_size )
267267 stride = self .vlane_stride
268268 if nr_rdim :
269269 val = per_lane // max (reduction_numel , 1 )
270+ result = val
270271 for mult in [8 , 4 , 2 ]:
271272 if per_lane >= val * mult :
272- return val * mult
273- return val
273+ result = val * mult
274+ break
275+ if self .forced_vec_size is not None :
276+ # Cap while keeping result divisible by val (= reduction_size).
277+ # This preserves the assert(vec_len % reduction_size == 0) invariant.
278+ capped = (min (result , self .forced_vec_size ) // max (val , 1 )) * max (val , 1 )
279+ result = max (capped , val )
280+ return result
281+ if self .forced_vec_size is not None :
282+ return self .forced_vec_size
274283 for mult in [8 , 4 , 2 ]:
275284 if (per_lane // stride ) >= mult :
276285 return stride * mult
@@ -787,10 +796,24 @@ def codegen_nodes(self, nodes, kernel_name):
787796 # Set node range info
788797 vars , reduction_vars = self .set_ranges (group , reduction_group )
789798 tile_desc = self .compute_tile_size (nodes , vars , reduction_vars )
799+ _ , _ , _ , self .buffer_types = self .kernel_group .args .mlir_argdefs ()
800+ safe_vec_size = self .get_safe_vec_size (tile_desc .get_compute_vec_size ())
801+ # For pointwise (non-reduction) kernels, cap the MLIR vector size so that
802+ # f16->f32 widening stays within LMUL<=4 (step and forced_vec_size must match).
803+ # Reduction kernels are left unchanged: their accumulator/multi_reduction
804+ # structure assumes compute_vec_size == step, so we must not split them here.
805+ tile_desc .vmap .forced_vec_size = safe_vec_size
806+ compute_vec = tile_desc .get_compute_vec_size ()
807+ # RVV requires vector lengths that produce integer power-of-2 LMUL values.
808+ # Non-power-of-2 element counts (e.g. 24) cause LLVM WidenVectorResult crashes.
809+ # Raise BEFORE the try/except so this propagates to make_choices (not retried).
810+ if compute_vec > 1 and (compute_vec & (compute_vec - 1 )) != 0 :
811+ raise RecompileSignal (
812+ f"Non-power-of-2 compute_vec_size { compute_vec } : tile rejected (RVV requires power-of-2 LMUL)"
813+ )
790814 self .compute_body_loop .size = tile_desc .get_numel_per_lane ()
791- self .compute_body_loop .step = tile_desc . get_compute_vec_size ()
815+ self .compute_body_loop .step = compute_vec
792816 try :
793- _ , _ , _ , self .buffer_types = self .kernel_group .args .mlir_argdefs ()
794817 with self as kernel :
795818 for node in nodes :
796819 node .run (vars , reduction_vars )
@@ -1035,6 +1058,42 @@ def __exit__(self, exc_type, exc_val, exc_tb):
10351058 self ._nested_context_depth -= 1
10361059 if self ._nested_context_depth == 0 :
10371060 super ().__exit__ (exc_type , exc_val , exc_tb )
1061+
def get_safe_vec_size(self, default_vec_size: int = 64) -> int:
    """
    Return a forced vector size that is safe for widening ops on
    low-precision inputs (e.g. f16/bf16 -> f32) under RVV LMUL limits.

    Widening is legal while the source operand stays at LMUL <= 4
    (destination LMUL <= 8).  From LMUL = (SEW * VL) / VLEN the safe
    source vector length is VL <= 4 * VLEN / SEW; this implementation
    deliberately caps at LMUL = 2 (see note below), i.e.
    VL <= 2 * VLEN / SEW.

    Falls back to ``default_vec_size`` when no buffer-type info is
    available or no low-precision buffers are present.
    """
    buffer_types = getattr(self, "buffer_types", None)
    if not buffer_types:
        return default_vec_size

    # Collect the bit widths of every low-precision FP buffer in play.
    lowp_bits = [
        MLIR_TO_BIT[DTYPE_TO_MLIR[info[0]]]
        for info in buffer_types.values()
        if info and info[0] in DTYPE_LOWP_FP
    ]
    if not lowp_bits:
        return default_vec_size

    # Constraint: Vector element count must be compatible across all types.
    # VLEN=256: f16 (LMUL=2) and f32 (LMUL=4) both yield 32 elements.
    # Note: Gem5 version restricts widening ops to LMUL < 8 for destination
    # registers, so max source LMUL is set to 2 for compatibility/safety.
    widen_safe_cap = self.vlen * 2 // min(lowp_bits)
    if widen_safe_cap <= 0:
        return default_vec_size

    return min(default_vec_size, widen_safe_cap)
10381097
10391098@dataclasses .dataclass
10401099class LoopLevel :
0 commit comments