@@ -166,7 +166,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
166166 from torch .library import triton_op , wrap_triton
167167
168168 @triton .jit
169- def _triton_calculate_scale (x , axis ):
169+ def _triton_calculate_scale (x , axis , SCALING_MODE : tl . constexpr ):
170170 # There is no good support for accessing globals from a jit'ed triton
171171 # function, so we redefine them here. Since this is prototype code which
172172 # we plan to remove after torch.compile catches up, this is fine.
@@ -179,23 +179,48 @@ def _triton_calculate_scale(x, axis):
179179 # Find the maximum absolute value for each row
180180 max_abs = tl .max (x , axis = axis )
181181
182- # Calculate the e8m0 scale by extracting the exponent (floor)
183- # TODO(future PR): support other exponent extraction types (ceil, RNE)
184- max_abs = max_abs .to (tl .bfloat16 )
185- max_abs_int16 = max_abs .to (tl .int16 , bitcast = True )
186- extracted_pow2 = ((max_abs_int16 >> bf16_mbits ) & 0b11111111 ) - bf16_exp_bias
187- extracted_pow2 = extracted_pow2 - target_max_pow2
188- scale_e8m0_unbiased = extracted_pow2 .to (tl .bfloat16 )
189-
190- # Clamp to exponents that can be represented in e8m0
191- # Add 1 to capture NaNs
192- scale_e8m0_unbiased = tl .clamp (
193- scale_e8m0_unbiased , - 1 * e8m0_exponent_bias , e8m0_exponent_bias + 1
194- )
182+ # Compute e8m0 biased scale using either RCEIL or FLOOR rounding.
183+ if SCALING_MODE == "rceil" :
184+ # RCEIL scaling mode using PTX instruction supported on sm100.
185+ # The input should be: amax / 448.0
186+ # where 448.0 is the max representable value in FP8 E4M3 format.
187+ F8E4M3_MAX_RCP : tl .constexpr = 1.0 / 448.0
188+ scale_input = max_abs .to (tl .float32 ) * F8E4M3_MAX_RCP
189+
190+ # The PTX instruction outputs a packed uint16 where:
191+ # - high byte = E8M0 of first input (0.0 in our case)
192+ # - low byte = E8M0 of second input (scale_input)
193+ # Casting uint16 to uint8 naturally truncates to the low byte.
194+ scale_e8m0_biased = tl .inline_asm_elementwise (
195+ asm = "cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;" ,
196+ constraints = "=h,r" ,
197+ args = [scale_input .to (tl .float32 , bitcast = False )],
198+ dtype = tl .uint16 ,
199+ is_pure = True ,
200+ pack = 1 ,
201+ ).to (tl .uint8 )
202+ else :
203+ tl .static_assert (SCALING_MODE == "floor" )
204+
205+ # Original floor implementation
206+ # Calculate the e8m0 scale by extracting the exponent (floor)
207+ max_abs = max_abs .to (tl .bfloat16 )
208+ max_abs_int16 = max_abs .to (tl .int16 , bitcast = True )
209+ extracted_pow2 = (
210+ (max_abs_int16 >> bf16_mbits ) & 0b11111111
211+ ) - bf16_exp_bias
212+ extracted_pow2 = extracted_pow2 - target_max_pow2
213+ scale_e8m0_unbiased = extracted_pow2 .to (tl .bfloat16 )
214+
215+ # Clamp to exponents that can be represented in e8m0
216+ # Add 1 to capture NaNs
217+ scale_e8m0_unbiased = tl .clamp (
218+ scale_e8m0_unbiased , - 1 * e8m0_exponent_bias , e8m0_exponent_bias + 1
219+ )
195220
196- # Create the biased e8m0 representation and cast it to 8 bits
197- scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
198- scale_e8m0_biased = scale_e8m0_biased .to (tl .uint8 )
221+ # Create the biased e8m0 representation and cast it to 8 bits
222+ scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
223+ scale_e8m0_biased = scale_e8m0_biased .to (tl .uint8 )
199224
200225 # TODO(future PR): add NaN handling here,
201226 # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
@@ -248,6 +273,7 @@ def to_mxfp8_dim1_kernel(
248273 ROW_TILE_SIZE : tl .constexpr ,
249274 COL_TILE_SIZE : tl .constexpr ,
250275 INNER_BLOCK_SIZE : tl .constexpr , # should be 32 for MX
276+ SCALING_MODE : tl .constexpr ,
251277 ):
252278 """
253279 Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
@@ -334,7 +360,11 @@ def to_mxfp8_dim1_kernel(
334360
335361 # Find the maximum absolute value for each column
336362 # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
337- col_scale_r , col_scale_e8m0_r = _triton_calculate_scale (x_block_abs_t_r , axis = 1 )
363+ col_scale_r , col_scale_e8m0_r = _triton_calculate_scale (
364+ x_block_abs_t_r ,
365+ axis = 1 ,
366+ SCALING_MODE = SCALING_MODE ,
367+ )
338368
339369 # Divide each column by scale
340370 # Broadcasting col_scale to match x_block's shape
@@ -397,6 +427,7 @@ def to_mxfp8_dim0_kernel(
397427 ROW_TILE_SIZE : tl .constexpr ,
398428 COL_TILE_SIZE : tl .constexpr ,
399429 SCALE_BLOCK_SIZE : tl .constexpr , # should be 32 for MX
430+ SCALING_MODE : tl .constexpr ,
400431 ):
401432 """
402433 Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
@@ -432,7 +463,9 @@ def to_mxfp8_dim0_kernel(
432463
433464 # Find the maximum absolute value for each row (across columns)
434465 # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
435- scale_fp32_r , scale_e8m0_r = _triton_calculate_scale (x_block_abs_r , axis = 1 )
466+ scale_fp32_r , scale_e8m0_r = _triton_calculate_scale (
467+ x_block_abs_r , axis = 1 , SCALING_MODE = SCALING_MODE
468+ )
436469
437470 # Divide each row by scale
438471 # Broadcasting scale to match x_block's shape
@@ -468,12 +501,15 @@ def to_mxfp8_dim0_kernel(
468501
469502 @triton_op ("torchao::triton_to_mxfp8_dim0" , mutates_args = {})
470503 def triton_to_mxfp8_dim0 (
471- x : torch .Tensor , inner_block_size : int = 32
504+ x : torch .Tensor ,
505+ inner_block_size : int = 32 ,
506+ scaling_mode : str = "rceil" ,
472507 ) -> Tuple [torch .Tensor , torch .Tensor ]:
473508 """
474509 Input:
475510 * `x` - input tensor, in row major memory layout
476511 * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
512+ * `scaling_mode` - e8m0 rounding mode: "floor" or "rceil" (default; rceil uses a PTX instruction requiring sm100+)
477513
478514 Output:
479515 * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
@@ -518,6 +554,7 @@ def triton_to_mxfp8_dim0(
518554 n_rows = n_rows ,
519555 n_cols = n_cols ,
520556 SCALE_BLOCK_SIZE = inner_block_size ,
557+ SCALING_MODE = scaling_mode ,
521558 )
522559
523560 # Reshape output back to original shape
@@ -531,7 +568,7 @@ def triton_to_mxfp8_dim0(
531568
532569 @triton_op ("torchao::triton_to_mxfp8_dim1" , mutates_args = {})
533570 def triton_to_mxfp8_dim1 (
534- x : torch .Tensor , inner_block_size : int = 32
571+ x : torch .Tensor , inner_block_size : int = 32 , scaling_mode : str = "rceil"
535572 ) -> Tuple [torch .Tensor , torch .Tensor ]:
536573 """
537574 Input:
@@ -583,6 +620,7 @@ def triton_to_mxfp8_dim1(
583620 n_rows = n_rows ,
584621 n_cols = n_cols ,
585622 INNER_BLOCK_SIZE = inner_block_size ,
623+ SCALING_MODE = scaling_mode ,
586624 )
587625
588626 return (
0 commit comments