Commit 29a521a

[mxfp8 moe training] add CUDA kernel for per-group conversion of scale factors to blocked layout
stack-info: PR: #3504, branch: danielvegamyhre/stack/86
1 parent 7035fb7 commit 29a521a
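
For context, a rough usage sketch of the new kernel, pieced together from the benchmark and test added in this commit. Shapes, group counts, and the final equality check are illustrative assumptions, not part of the diff; the CUDA path mirrors the existing Triton kernel's interface.

# Illustrative sketch assembled from the benchmark/test in this commit; values are made up.
import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    mx_block_rearrange_2d_K_groups_cuda,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = "cuda"
block_size = 32
m, total_k, n_groups = 256, 2048, 4

# Per-block e8m0 scale factors for an (m, total_k) operand: one scale per 32 elements along K.
e8m0_scales = torch.randint(
    0, 256, (m, total_k // block_size), dtype=torch.uint8, device=device
).view(torch.float8_e8m0fnu)

# Group end offsets along total_k, converted into the scale-column domain.
scale_group_offsets = (
    generate_jagged_offs(n_groups, total_k, multiple_of=block_size, device=device)
    // block_size
)

# New CUDA kernel vs. existing Triton kernel; max_cols / chunks_per_tb are optional tuning knobs.
blocked_cuda = mx_block_rearrange_2d_K_groups_cuda(e8m0_scales, scale_group_offsets)
blocked_triton = triton_mx_block_rearrange_2d_K_groups(e8m0_scales, scale_group_offsets)
assert torch.equal(blocked_triton, blocked_cuda.view(torch.float8_e8m0fnu))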

6 files changed: +871 -0 lines changed
New file: MXFP8 per-group scale-rearrange benchmark script
Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    mx_block_rearrange_2d_K_groups_cuda,
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int]
    num_groups: int
    version: str  # "torch", "triton", or "cuda_{max_cols}_{chunks_per_tb}"


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float
    mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_M dim, which contains all the token groups.
    block_size = 32
    input_shapes = [
        (8192, 32768 // block_size),
        (8192, 65536 // block_size),
        (8192, 131072 // block_size),
        (5120, 32768 // block_size),
        (5120, 65536 // block_size),
        (5120, 131072 // block_size),
        (7168, 32768 // block_size),
        (7168, 65536 // block_size),
        (7168, 131072 // block_size),
        (2048, 32768 // block_size),
        (2048, 65536 // block_size),
        (2048, 131072 // block_size),
    ]
    num_groups = [8]
    versions = [
        "torch",
        "triton",
        # CUDA kernel versions: cuda_{max_cols}_{chunks_per_tb}
        "cuda_64_4",
        "cuda_64_8",
        "cuda_64_16",
        "cuda_128_4",
        "cuda_128_8",
        "cuda_128_16",
    ]

    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # Select which kernel to benchmark based on version
    if version == "torch":
        kernel_fn = torch_to_blocked_2d_K_groups
        kernel_input = input_tensor
    elif version == "triton":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups
        # Triton uses row-major input
        kernel_input = input_tensor
    elif version.startswith("cuda_"):
        # Parse version string: cuda_{max_cols}_{chunks_per_tb}
        parts = version.split("_")
        max_cols = int(parts[1])
        chunks_per_tb = int(parts[2])
        kernel_fn = lambda t, o, mc=max_cols, cptb=chunks_per_tb: (
            mx_block_rearrange_2d_K_groups_cuda(t, o, max_cols=mc, chunks_per_tb=cptb)
        )
        kernel_input = input_tensor.view(torch.float8_e8m0fnu)
    else:
        raise ValueError(f"Unknown version: {version}")

    # Run kernel to get output shape
    outputs = kernel_fn(
        kernel_input,
        input_group_offsets,
    )
    if isinstance(outputs, tuple):  # torch returns a tuple with extra metadata
        out_scales, _ = outputs
    else:
        out_scales = outputs

    # Benchmark the kernel
    time_us = benchmark_cuda_function_in_microseconds(
        kernel_fn,
        kernel_input,
        input_group_offsets,
    )

    # Calculate memory bandwidth
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = out_scales.numel() * bytes_per_output_el

    mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)

    return ExperimentResult(
        time_us=time_us,
        mem_bw_gbps=mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    # Group experiments by input shape
    shapes_dict = {}
    for exp in experiments:
        shape_key = exp.config.input_shape
        if shape_key not in shapes_dict:
            shapes_dict[shape_key] = {}
        shapes_dict[shape_key][exp.config.version] = exp.result

    headers = [
        "kernel_version",
        "scale_shape",
        "time_us",
        "mem_bw_gbps",
        "speedup_vs_torch",
        "speedup_vs_triton",
    ]

    rows = []
    for shape, versions in shapes_dict.items():
        # Get torch baseline time for speedup calculation
        torch_time_us = versions.get("torch").time_us if "torch" in versions else None

        # Get triton baseline time for speedup calculation
        triton_time_us = (
            versions.get("triton").time_us if "triton" in versions else None
        )

        # Add rows for each version
        for version, result in versions.items():
            # Calculate speedup vs torch
            speedup_vs_torch_str = ""
            if version != "torch" and torch_time_us is not None:
                speedup = torch_time_us / result.time_us
                speedup_vs_torch_str = f"{speedup:.2f}x"

            # Calculate speedup vs triton (only for CUDA kernels)
            speedup_vs_triton_str = ""
            if version.startswith("cuda_") and triton_time_us is not None:
                speedup = triton_time_us / result.time_us
                speedup_vs_triton_str = f"{speedup:.2f}x"

            rows.append(
                [
                    version,
                    f"({shape[0]}, {shape[1]})",
                    f"{result.time_us:.2f}",
                    round(result.mem_bw_gbps, 3),
                    speedup_vs_torch_str,
                    speedup_vs_triton_str,
                ]
            )

        # Find best CUDA kernel speedup vs triton for this shape
        best_cuda_speedup = 0.0
        best_cuda_version = None
        for version, result in versions.items():
            if version.startswith("cuda_") and triton_time_us is not None:
                speedup = triton_time_us / result.time_us
                if speedup > best_cuda_speedup:
                    best_cuda_speedup = speedup
                    best_cuda_version = version

        if best_cuda_version is not None:
            rows.append(
                [
                    f">>> BEST: {best_cuda_speedup:.2f}x vs triton with {best_cuda_version}",
                    "",
                    "",
                    "",
                    "",
                ]
            )

        # Add empty row for visual separation between shapes
        rows.append([""] * len(headers))

    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
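
For a rough sense of the bandwidth formula in run_experiment, a worked example for the smallest benchmarked scale shape. The kernel time below is a made-up placeholder, not a measured result from this commit:

# Illustrative arithmetic only; the timing is a placeholder, not a measurement.
M, Kg = 2048, 32768 // 32          # smallest benchmarked scale shape: (2048, 1024)
read_bytes = M * Kg                # e8m0 scales are 1 byte per element
write_bytes = M * Kg               # lower bound; the blocked output is padded per group
time_us = 50.0                     # hypothetical kernel time
mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{mem_bw_gbps:.1f} GB/s")   # ~83.9 GB/s for these made-up numbers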

setup.py

Lines changed: 1 addition & 0 deletions
@@ -709,6 +709,7 @@ def get_extensions():
     mxfp8_sources = [
         os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
         os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
+        os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"),
     ]

     # Only add the extension if the source files exist AND we are building for sm100
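
The gating described in that trailing comment is not shown in this hunk. As a rough idea of the pattern, here is a hypothetical sketch; the helper name and logic are illustrative assumptions, not the repo's actual setup.py code:

# Hypothetical sketch of the gating described above; names and logic are illustrative.
import os

def should_build_mxfp8_extension(mxfp8_sources, cuda_arch_list):
    # Require all listed C++/CUDA sources to exist on disk.
    if not all(os.path.exists(src) for src in mxfp8_sources):
        return False
    # Only build when the target arch list includes sm100 (compute capability 10.0+).
    return any(arch.startswith("10.") for arch in cuda_arch_list)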

test/prototype/moe_training/test_kernels.py

Lines changed: 54 additions & 0 deletions
@@ -352,3 +352,57 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     # Check quantized values
     torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
     assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
+@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
+@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
+def test_cuda_mx_block_rearrange_2d_K_groups(
+    m: int,
+    total_k: int,
+    n_groups: int,
+):
+    """
+    Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference.
+    """
+    from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+        mx_block_rearrange_2d_K_groups_cuda,
+    )
+
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, total_k, device=device)
+
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+
+    # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
+    input_group_offsets = generate_jagged_offs(
+        n_groups, total_k, multiple_of=block_size, device=device
+    )
+    scale_group_offsets = input_group_offsets // block_size
+
+    # Triton reference implementation
+    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
+        e8m0_scales,
+        scale_group_offsets,
+    )
+
+    # CUDA kernel implementation
+    cuda_out_scales = mx_block_rearrange_2d_K_groups_cuda(
+        e8m0_scales,
+        scale_group_offsets,
+    )
+
+    # Check that outputs match
+    assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), (
+        "CUDA and Triton blocked scales not equal"
+    )
+
+    # Check strides
+    assert triton_out_scales.stride() == cuda_out_scales.stride(), "strides not equal"
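
Assuming the repo's usual pytest workflow, the new case can be run on its own with a keyword filter, e.g. "pytest test/prototype/moe_training/test_kernels.py -k mx_block_rearrange_2d_K_groups" (illustrative invocation, not part of the diff).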
