Commit b57a58e

[mxfp8 moe training] integrate new cuda kernel for blocked layout for groups along K
stack-info: PR: #3505, branch: danielvegamyhre/stack/87

1 parent: 0dfef18

4 files changed: +11 -13 lines

torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Lines changed: 5 additions & 7 deletions

@@ -114,13 +114,11 @@ mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
     output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);
-    // Need scales_colwise to be this shape so the 'col' dim stride is 1,
-    // for colwise scaling, we can avoid uncoalesced writes to global memory.
-    // This is because each of the 32 threads in a warp will be computing
-    // a scale for a different column of 32 input data values, then each writing
-    // that scale to global memory - so the stride along this `col` dim should be 1
-    // so writes can be coalesced into a single transaction.
-    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+
+    // Accept uncoalesced global stores for scale tensor, since row-major is much more favorable for the subsequent
+    // per-group blocked format kernel.
+    // Microbenchmarks show the memory bandwidth utilization is virtually identical to coalesced global stores.
+    scales_colwise = at::empty({cols, num_row_blocks}, options_scale);
   } else {
     output_colwise = at::empty({0}, options_fp8);
     scales_colwise = at::empty({0}, options_scale);
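
For illustration, a minimal PyTorch sketch (not part of the commit) contrasting the two scale layouts; the shapes are made up and uint8 stands in for the e8m0 scale dtype:

import torch

cols, num_row_blocks = 4096, 128

# Old layout: stride 1 along the 'col' dim (column-major), so the 32 scale
# writes from a warp land in consecutive addresses and coalesce.
old_scales = torch.empty_strided(
    (cols, num_row_blocks), (1, cols), dtype=torch.uint8
)

# New layout: default contiguous row-major. Stores from the quantize kernel
# are uncoalesced, but the downstream per-group blocked-format kernel reads
# row-major far more efficiently, and measured bandwidth is about the same.
new_scales = torch.empty((cols, num_row_blocks), dtype=torch.uint8)

print(old_scales.stride())  # (1, 4096)
print(new_scales.stride())  # (128, 1)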

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    mx_block_rearrange_2d_K_groups_cuda,  # noqa: F401
     mxfp8_quantize_cuda_3d,  # noqa: F401
     torch_to_blocked_2d_K_groups,  # noqa: F401
     torch_to_blocked_2d_M_groups,  # noqa: F401

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 3 additions & 3 deletions

@@ -17,8 +17,8 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_K_groups_cuda,
     mxfp8_quantize_cuda_3d,
-    triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -437,11 +437,11 @@ def backward(ctx, grad_out: torch.Tensor):

         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             grad_out_t_scales,
             scale_group_offsets,
         )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             A_t_scales,
             scale_group_offsets,
         )
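
As a sanity check on the scale_group_offsets arithmetic above, a small worked example (values are hypothetical, not from the commit): each block_size elements along K produce one scale, so integer-dividing the K-dim group offsets by block_size gives the matching group boundaries in the scale tensor.

import torch

block_size = 32  # MX scaling block size along K
offs = torch.tensor([64, 160, 256])  # hypothetical group-end offsets along K

# Group boundaries in scale space line up with the element boundaries,
# shrunk by the per-block scaling factor.
scale_group_offsets = offs // block_size
print(scale_group_offsets)  # tensor([2, 5, 8])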

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 3 deletions

@@ -1181,10 +1181,9 @@ def _fake_mxfp8_quantize(
         (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
     )

-    # colwise scales are written in column-major format to avoid uncoalesced global memory accesses
-    scales_colwise = torch.empty_strided(
+    # Row-major scales are more favorable for the subsequent blocked format kernel, and microbenchmarks show bandwidth is virtually identical to coalesced stores
+    scales_colwise = torch.empty(
         (cols, num_row_blocks),
-        (1, cols),
         dtype=torch.float8_e8m0fnu,
         device=x.device,
     )
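
A quick sketch (shapes invented, uint8 in place of float8_e8m0fnu) of why this fake op must change in lockstep with the CUDA kernel: under torch.compile, the traced output strides come from the fake implementation, so they need to match the contiguous layout the real kernel now produces.

import torch

rows, cols, scale_dim_y = 256, 512, 32
num_row_blocks = (rows + scale_dim_y - 1) // scale_dim_y

# Matches the new kernel output: contiguous row-major scales.
scales_colwise = torch.empty(
    (cols, num_row_blocks), dtype=torch.uint8, device="meta"
)
assert scales_colwise.is_contiguous()
assert scales_colwise.stride() == (num_row_blocks, 1)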
