Skip to content

Commit 3d4d7b1

Browse files
[mxfp8 moe training] test floor vs rceil in 3d quantization benchmarks
stack-info: PR: #3507, branch: danielvegamyhre/stack/89
1 parent: 613953d; commit: 3d4d7b1

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77

8+
import itertools
89
from dataclasses import dataclass
910
from typing import List
1011

@@ -17,6 +18,7 @@
1718
from torchao.prototype.moe_training.scaled_grouped_mm import (
1819
_to_mxfp8_dim1_3d,
1920
)
21+
from torchao.prototype.mx_formats.config import ScaleCalculationMode
2022
from torchao.prototype.mx_formats.mx_tensor import to_mx
2123

2224
device = torch.device("cuda")
@@ -28,6 +30,7 @@
2830
@dataclass(frozen=True)
2931
class ExperimentConfig:
3032
input_shape: tuple[int]
33+
scaling_mode: ScaleCalculationMode
3134

3235

3336
@dataclass(frozen=True)
@@ -58,11 +61,13 @@ def get_configs() -> List[ExperimentConfig]:
5861
(16, 8192, 5120),
5962
(64, 8192, 5120),
6063
]
64+
round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
6165
configs = []
62-
for shape in input_shapes:
66+
for shape, scaling_mode in itertools.product(input_shapes, round_modes):
6367
configs.append(
6468
ExperimentConfig(
6569
input_shape=shape,
70+
scaling_mode=scaling_mode,
6671
)
6772
)
6873
return configs
@@ -76,6 +81,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
7681
dtype=torch.bfloat16,
7782
device=device,
7883
)
84+
scaling_mode = config.scaling_mode
7985

8086
def using_to_mx(x: torch.Tensor) -> torch.Tensor:
8187
# Reference implementation
@@ -85,6 +91,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
8591
x.transpose(-2, -1).contiguous(),
8692
elem_dtype=torch.float8_e4m3fn,
8793
block_size=block_size,
94+
scaling_mode=scaling_mode,
8895
)
8996

9097
# Transpose tensors and scales back so we have effectively
@@ -107,13 +114,17 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
107114
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
108115
using_cuda_2d_c,
109116
input_tensor,
117+
block_size=block_size,
118+
scaling_mode=scaling_mode,
110119
)
111120

112121
# bench 3d cuda kernel
113122
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
114123
time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
115124
mxfp8_quantize_cuda_3d,
116125
input_tensor,
126+
block_size=block_size,
127+
scaling_mode=str(scaling_mode.value),
117128
)
118129

119130
# mem bw calculations
@@ -146,6 +157,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
146157
def print_results(experiments: List[Experiment]):
147158
headers = [
148159
"input_shape",
160+
"scaling_mode",
149161
"to_mx_us",
150162
"cuda_2d_us",
151163
"cuda_3d_us",
@@ -158,6 +170,7 @@ def print_results(experiments: List[Experiment]):
158170
rows.append(
159171
[
160172
str(experiment.config.input_shape),
173+
str(experiment.config.scaling_mode),
161174
experiment.result.to_mx_us,
162175
experiment.result.cuda_2d_us,
163176
experiment.result.cuda_3d_us,

0 commit comments

Comments
 (0)