
Commit 8806b02

[mxfp8] support RCEIL in triton_to_mxfp8_dim0 kernel with inline PTX (#3498)
1 parent 1f9bfd7 commit 8806b02
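This commit adds a `scaling_mode` argument ("floor" or "rceil") to the triton mxfp8 cast ops, with the rceil path computing e8m0 scales via the sm100 PTX instruction cvt.rp.satfinite.ue8m0x2.f32. A minimal usage sketch of the new argument (the import path and the sm100 hardware requirement are inferred from the tests below, not stated by the commit itself):

    import torch
    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0  # assumed import path

    # rowwise (1x32) mxfp8 cast of a bf16 tensor; "floor" keeps the previous behavior,
    # "rceil" uses the new inline-PTX scale computation (sm100 per the tests below)
    x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
    data_fp8, scale_e8m0 = triton_to_mxfp8_dim0(x, inner_block_size=32, scaling_mode="rceil")
    assert data_fp8.dtype == torch.float8_e4m3fn
    assert scale_e8m0.dtype == torch.float8_e8m0fnu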

File tree: 4 files changed, +134 -38 lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 27 additions & 2 deletions
@@ -112,6 +112,7 @@ def run(
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
         "dim0_mxfp8_triton_floor",
+        "dim0_mxfp8_triton_rceil",
         "dim0_nvfp4",
         "dim0_nvfp4_triton_swizzle",
         "dim1_mxfp8_floor",
@@ -243,9 +244,33 @@ def run(
         y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

         for _ in range(2):
-            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+            __ = triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            )
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            lambda x, b: triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim0_mxfp8_triton_rceil":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            ),
             x,
             BLOCK_SIZE,
         )
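As a rough sanity check on the bandwidth accounting above (illustrative numbers, not measured results): for a hypothetical 4096x4096 bf16 input with 1x32 scaling blocks, bytes_r = 4096 * 4096 * 2 ≈ 33.6 MB and bytes_w = (4096 * 4096 + 4096 * 4096 / 32) * 1 ≈ 17.3 MB, so a 100 us cast would report bps ≈ (33.6e6 + 17.3e6) / 100e-6 ≈ 0.51 TB/s.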

test/prototype/mx_formats/test_kernels.py

Lines changed: 32 additions & 11 deletions
@@ -45,7 +45,6 @@
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     is_cuda_version_at_least,
-    is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -423,15 +422,19 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):


 def triton_to_mxfp8_dim0_reference(
-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
     """
     from torchao.prototype.mx_formats.mx_tensor import to_mx

     # cast across dim0 (rowwise) - no transpose needed
-    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
     return (
         x_hp_d0_normalized,
@@ -441,8 +444,8 @@ def triton_to_mxfp8_dim0_reference(

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
-    reason="float8 in triton requires CUDA capability 8.9 or greater",
+    not is_sm_at_least_100(),
+    reason="mxfp8 in triton requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
@@ -461,10 +464,19 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
-def test_triton_mxfp8_dim0_randn(M, K):
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_triton_mxfp8_dim0_randn(M, K, scaling_mode):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
-    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
+        x, block_size=32, scaling_mode=scaling_mode
+    )
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
+        x,
+        inner_block_size=32,
+        scaling_mode=scaling_mode.value.lower(),
+    )
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

@@ -474,10 +486,19 @@ def test_triton_mxfp8_dim0_randn(M, K):
     not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
-def test_triton_mxfp8_dim0_zeros():
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_triton_mxfp8_dim0_zeros(scaling_mode):
     x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
-    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
-    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
+        x, block_size=32, scaling_mode=scaling_mode
+    )
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
+        x,
+        inner_block_size=32,
+        scaling_mode=scaling_mode.value.lower(),
+    )
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 16 additions & 4 deletions
@@ -121,8 +121,13 @@ def test_linear_eager_vs_hp(
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

     if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
-        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
-            pytest.skip("unsupported configuration")
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("triton mxfp8 quantization kernels only support FLOOR and RCEIL scaling modes")
+        if not is_sm_at_least_100():
+            pytest.skip("triton mxfp8 quantization kernels require sm100")
     elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
         if scale_calculation_mode not in (
             ScaleCalculationMode.FLOOR,
@@ -316,8 +321,15 @@ def test_linear_compile(
         pytest.skip("unsupported configuration")

     if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
-        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
-            pytest.skip("unsupported configuration")
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip(
+                "triton mxfp8 quantization kernels only support FLOOR and RCEIL scaling modes"
+            )
+        if not is_sm_at_least_100():
+            pytest.skip("triton mxfp8 quantization kernels require sm100")
     elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
         if scale_calculation_mode not in (
             ScaleCalculationMode.FLOOR,

torchao/prototype/mx_formats/kernels.py

Lines changed: 59 additions & 21 deletions
@@ -166,7 +166,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     from torch.library import triton_op, wrap_triton

     @triton.jit
-    def _triton_calculate_scale(x, axis):
+    def _triton_calculate_scale(x, axis, SCALING_MODE: tl.constexpr):
         # There is no good support for accessing globals from a jit'ed triton
         # function, so we redefine them here. Since this is prototype code which
         # we plan to remove after torch.compile catches up, this is fine.
@@ -179,23 +179,48 @@ def _triton_calculate_scale(x, axis):
         # Find the maximum absolute value for each row
         max_abs = tl.max(x, axis=axis)

-        # Calculate the e8m0 scale by extracting the exponent (floor)
-        # TODO(future PR): support other exponent extraction types (ceil, RNE)
-        max_abs = max_abs.to(tl.bfloat16)
-        max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
-        extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
-        extracted_pow2 = extracted_pow2 - target_max_pow2
-        scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
-
-        # Clamp to exponents that can be represented in e8m0
-        # Add 1 to capture NaNs
-        scale_e8m0_unbiased = tl.clamp(
-            scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
-        )
+        # Compute e8m0 biased scale using either RCEIL or FLOOR rounding.
+        if SCALING_MODE == "rceil":
+            # RCEIL scaling mode using PTX instruction supported on sm100.
+            # The input should be: amax / 448.0
+            # where 448.0 is the max representable value in FP8 E4M3 format.
+            F8E4M3_MAX_RCP: tl.constexpr = 1.0 / 448.0
+            scale_input = max_abs.to(tl.float32) * F8E4M3_MAX_RCP
+
+            # The PTX instruction outputs a packed uint16 where:
+            # - high byte = E8M0 of first input (0.0 in our case)
+            # - low byte = E8M0 of second input (scale_input)
+            # Casting uint16 to uint8 naturally truncates to the low byte.
+            scale_e8m0_biased = tl.inline_asm_elementwise(
+                asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
+                constraints="=h,r",
+                args=[scale_input.to(tl.float32, bitcast=False)],
+                dtype=tl.uint16,
+                is_pure=True,
+                pack=1,
+            ).to(tl.uint8)
+        else:
+            tl.static_assert(SCALING_MODE == "floor")
+
+            # Original floor implementation
+            # Calculate the e8m0 scale by extracting the exponent (floor)
+            max_abs = max_abs.to(tl.bfloat16)
+            max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
+            extracted_pow2 = (
+                (max_abs_int16 >> bf16_mbits) & 0b11111111
+            ) - bf16_exp_bias
+            extracted_pow2 = extracted_pow2 - target_max_pow2
+            scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
+
+            # Clamp to exponents that can be represented in e8m0
+            # Add 1 to capture NaNs
+            scale_e8m0_unbiased = tl.clamp(
+                scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
+            )

-        # Create the biased e8m0 representation and cast it to 8 bits
-        scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
-        scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
+            # Create the biased e8m0 representation and cast it to 8 bits
+            scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
+            scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

         # TODO(future PR): add NaN handling here,
         # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
@@ -248,6 +273,7 @@ def to_mxfp8_dim1_kernel(
         ROW_TILE_SIZE: tl.constexpr,
         COL_TILE_SIZE: tl.constexpr,
         INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+        SCALING_MODE: tl.constexpr,
     ):
         """
         Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
@@ -334,7 +360,11 @@ def to_mxfp8_dim1_kernel(

         # Find the maximum absolute value for each column
         # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
-        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1)
+        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(
+            x_block_abs_t_r,
+            axis=1,
+            SCALING_MODE=SCALING_MODE,
+        )

         # Divide each column by scale
         # Broadcasting col_scale to match x_block's shape
@@ -397,6 +427,7 @@ def to_mxfp8_dim0_kernel(
         ROW_TILE_SIZE: tl.constexpr,
         COL_TILE_SIZE: tl.constexpr,
         SCALE_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+        SCALING_MODE: tl.constexpr,
     ):
         """
         Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
@@ -432,7 +463,9 @@ def to_mxfp8_dim0_kernel(

         # Find the maximum absolute value for each row (across columns)
         # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-        scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+        scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(
+            x_block_abs_r, axis=1, SCALING_MODE=SCALING_MODE
+        )

         # Divide each row by scale
         # Broadcasting scale to match x_block's shape
@@ -468,12 +501,15 @@ def to_mxfp8_dim0_kernel(

     @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
     def triton_to_mxfp8_dim0(
-        x: torch.Tensor, inner_block_size: int = 32
+        x: torch.Tensor,
+        inner_block_size: int = 32,
+        scaling_mode: str = "rceil",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
         * `x` - input tensor, in row major memory layout
         * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+        * `scaling_mode` - floor or rceil

         Output:
         * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
@@ -518,6 +554,7 @@ def triton_to_mxfp8_dim0(
             n_rows=n_rows,
             n_cols=n_cols,
             SCALE_BLOCK_SIZE=inner_block_size,
+            SCALING_MODE=scaling_mode,
         )

         # Reshape output back to original shape
@@ -531,7 +568,7 @@ def triton_to_mxfp8_dim0(

     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
-        x: torch.Tensor, inner_block_size: int = 32
+        x: torch.Tensor, inner_block_size: int = 32, scaling_mode: str = "rceil"
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
@@ -583,6 +620,7 @@ def triton_to_mxfp8_dim1(
             n_rows=n_rows,
             n_cols=n_cols,
             INNER_BLOCK_SIZE=inner_block_size,
+            SCALING_MODE=scaling_mode,
         )

         return (
