From 36109b7f23ee9e8c996c0c5ddecf0364cbdd76da Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 16 Dec 2025 18:21:49 -0800
Subject: [PATCH] [mxfp8 moe training] test floor vs rceil in 3d quantization
 benchmarks

stack-info: PR: https://github.com/pytorch/ao/pull/3507, branch: danielvegamyhre/stack/89
---
 .../moe_training/mxfp8/bench_quantize_3d.py   | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
index b57ca81d4c..4cde6c9cb4 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+import itertools
 from dataclasses import dataclass
 from typing import List
 
@@ -17,6 +18,7 @@ from torchao.prototype.moe_training.scaled_grouped_mm import (
     _to_mxfp8_dim1_3d,
 )
+from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 device = torch.device("cuda")
 
@@ -28,6 +30,7 @@
 @dataclass(frozen=True)
 class ExperimentConfig:
     input_shape: tuple[int]
+    scaling_mode: ScaleCalculationMode
 
 
 @dataclass(frozen=True)
@@ -58,11 +61,13 @@ def get_configs() -> List[ExperimentConfig]:
         (16, 8192, 5120),
         (64, 8192, 5120),
     ]
+    round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
     configs = []
-    for shape in input_shapes:
+    for shape, scaling_mode in itertools.product(input_shapes, round_modes):
         configs.append(
             ExperimentConfig(
                 input_shape=shape,
+                scaling_mode=scaling_mode,
             )
         )
     return configs
@@ -76,6 +81,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
         dtype=torch.bfloat16,
         device=device,
     )
+    scaling_mode = config.scaling_mode
 
     def using_to_mx(x: torch.Tensor) -> torch.Tensor:
         # Reference implementation
@@ -85,6 +91,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
             x.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scaling_mode,
         )
 
         # Transpose tensors and scales back so we have effectively
@@ -107,6 +114,8 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
         using_cuda_2d_c,
         input_tensor,
+        block_size=block_size,
+        scaling_mode=scaling_mode,
     )
 
     # bench 3d cuda kernel
@@ -114,6 +123,8 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
         mxfp8_quantize_cuda_3d,
         input_tensor,
+        block_size=block_size,
+        scaling_mode=str(scaling_mode.value),
     )
 
     # mem bw calculations
@@ -146,6 +157,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
+        "scaling_mode",
         "to_mx_us",
         "cuda_2d_us",
         "cuda_3d_us",
@@ -158,6 +170,7 @@
         rows.append(
             [
                 str(experiment.config.input_shape),
+                str(experiment.config.scaling_mode),
                 experiment.result.to_mx_us,
                 experiment.result.cuda_2d_us,
                 experiment.result.cuda_3d_us,
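
For context on what the two modes being benchmarked do: FLOOR truncates the per-block power-of-two scale exponent toward negative infinity, while RCEIL rounds it up so that amax divided by the scale cannot overflow the fp8 range. Below is a minimal illustrative sketch of that rounding difference on a single MX block of 32 elements. It is a simplified model, not the torchao kernels (torchao derives the exponent via bit manipulation of amax, and the helper names here are made up for illustration); it assumes the float8_e4m3fn max magnitude of 448.

    import torch

    F8E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

    def scale_floor(amax: torch.Tensor) -> torch.Tensor:
        # FLOOR (sketch): truncate log2(amax / fp8_max) toward -inf to get
        # the power-of-two (e8m0-style) scale.
        return torch.exp2(torch.floor(torch.log2(amax / F8E4M3_MAX)))

    def scale_rceil(amax: torch.Tensor) -> torch.Tensor:
        # RCEIL (sketch): round the same ratio up to the next power of two,
        # so amax / scale never exceeds the fp8 representable range.
        return torch.exp2(torch.ceil(torch.log2(amax / F8E4M3_MAX)))

    torch.manual_seed(0)
    block = torch.randn(32) * 1000.0  # one MX block of 32 elements
    amax = block.abs().max()
    for name, s in [("floor", scale_floor(amax)), ("rceil", scale_rceil(amax))]:
        # Quantize with the chosen scale, clamping to the fp8 range, then
        # dequantize and report the max absolute reconstruction error.
        q = (block / s).clamp(-F8E4M3_MAX, F8E4M3_MAX).to(torch.float8_e4m3fn)
        err = (q.float() * s - block).abs().max()
        print(name, s.item(), err.item())

In this simplified model, FLOOR can leave amax / scale above 448 (hence the clamp, which saturates the largest values), while RCEIL trades a one-step-coarser scale for a guarantee of no overflow; the benchmark change above measures whether that choice also affects quantization kernel throughput.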