Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

Expand All @@ -17,6 +18,7 @@
from torchao.prototype.moe_training.scaled_grouped_mm import (
_to_mxfp8_dim1_3d,
)
from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.mx_tensor import to_mx

device = torch.device("cuda")
Expand All @@ -28,6 +30,7 @@
@dataclass(frozen=True)
class ExperimentConfig:
input_shape: tuple[int]
scaling_mode: ScaleCalculationMode


@dataclass(frozen=True)
Expand Down Expand Up @@ -58,11 +61,13 @@ def get_configs() -> List[ExperimentConfig]:
(16, 8192, 5120),
(64, 8192, 5120),
]
round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
configs = []
for shape in input_shapes:
for shape, scaling_mode in itertools.product(input_shapes, round_modes):
configs.append(
ExperimentConfig(
input_shape=shape,
scaling_mode=scaling_mode,
)
)
return configs
Expand All @@ -76,6 +81,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
dtype=torch.bfloat16,
device=device,
)
scaling_mode = config.scaling_mode

def using_to_mx(x: torch.Tensor) -> torch.Tensor:
# Reference implementation
Expand All @@ -85,6 +91,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
x.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=scaling_mode,
)

# Transpose tensors and scales back so we have effectively
Expand All @@ -107,13 +114,17 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
using_cuda_2d_c,
input_tensor,
block_size=block_size,
scaling_mode=scaling_mode,
)

# bench 3d cuda kernel
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor)
time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
mxfp8_quantize_cuda_3d,
input_tensor,
block_size=block_size,
scaling_mode=str(scaling_mode.value),
)

# mem bw calculations
Expand Down Expand Up @@ -146,6 +157,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"scaling_mode",
"to_mx_us",
"cuda_2d_us",
"cuda_3d_us",
Expand All @@ -158,6 +170,7 @@ def print_results(experiments: List[Experiment]):
rows.append(
[
str(experiment.config.input_shape),
str(experiment.config.scaling_mode),
experiment.result.to_mx_us,
experiment.result.cuda_2d_us,
experiment.result.cuda_3d_us,
Expand Down
Loading