Commit c076eb1

[mxfp8 moe training] integrate new cuda kernel for blocked layout for groups along K
stack-info: PR: #3505, branch: danielvegamyhre/stack/87
1 parent 492b8ce commit c076eb1

4 files changed (+62, -64 lines)

torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Lines changed: 5 additions & 7 deletions
@@ -114,13 +114,11 @@ mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
     output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);
-    // Need scales_colwise to be this shape so the 'col' dim stride is 1,
-    // for colwise scaling, we can avoid uncoalesced writes to global memory.
-    // This is because each of the 32 threads in a warp will be computing
-    // a scale for a different column of 32 input data values, then each writing
-    // that scale to global memory - so the stride along this `col` dim should be 1
-    // so writes can be coalesced into a single transaction.
-    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+
+    // Accept uncoalesced global stores for the scale tensor, since row-major is much more favorable
+    // for the subsequent per-group blocked format kernel.
+    // Microbenchmarks show the memory bandwidth utilization is virtually identical to coalesced global stores.
+    scales_colwise = at::empty({cols, num_row_blocks}, options_scale);
   } else {
     output_colwise = at::empty({0}, options_fp8);
     scales_colwise = at::empty({0}, options_scale);
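For intuition, a minimal PyTorch sketch of the layout change this hunk makes (sizes are illustrative only, and uint8 stands in for the e8m0 scale dtype so the sketch runs anywhere): the old empty_strided allocation gave the scale tensor a column-major layout with unit stride along the cols dim, while plain empty gives a contiguous row-major layout.

import torch

cols, num_row_blocks = 8, 4  # illustrative sizes only

# Old layout: column-major (stride 1 along dim 0), so per-warp scale writes coalesce.
col_major = torch.empty_strided((cols, num_row_blocks), (1, cols), dtype=torch.uint8)

# New layout: contiguous row-major, which the subsequent per-group blocked-format kernel prefers.
row_major = torch.empty((cols, num_row_blocks), dtype=torch.uint8)

print(col_major.stride())  # (1, 8)
print(row_major.stride())  # (4, 1)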

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    mx_block_rearrange_2d_K_groups_cuda,  # noqa: F401
     mxfp8_quantize_cuda_3d,  # noqa: F401
     torch_to_blocked_2d_K_groups,  # noqa: F401
     torch_to_blocked_2d_M_groups,  # noqa: F401

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 54 additions & 54 deletions
@@ -17,8 +17,8 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_K_groups_cuda,
     mxfp8_quantize_cuda_3d,
-    triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -92,28 +92,28 @@ def forward(
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

-        assert A.size(-1) % 16 == 0, (
-            f"A must have a last dim divisible by 16, but got shape: {A.shape}"
-        )
-        assert B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0, (
-            f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
-        )
+        assert (
+            A.size(-1) % 16 == 0
+        ), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
+        assert (
+            B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
+        ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"

         # Assert input tensors are in high-precision dtypes.
-        assert A.dtype == torch.float32 or A.dtype == torch.bfloat16, (
-            "A must be float32 or bfloat16"
-        )
-        assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
-            "B must be float32 or bfloat16"
-        )
-        assert offs is None or offs.dtype == torch.int32, (
-            "offs must be int32 tensor or None"
-        )
+        assert (
+            A.dtype == torch.float32 or A.dtype == torch.bfloat16
+        ), "A must be float32 or bfloat16"
+        assert (
+            B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
+        ), "B must be float32 or bfloat16"
+        assert (
+            offs is None or offs.dtype == torch.int32
+        ), "offs must be int32 tensor or None"

         # Assert A and B dims are compatible for a scaled grouped GEMM.
-        assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
-        )
+        assert A.size(-1) == B_t.size(
+            -2
+        ), f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"

         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
         assert not _is_column_major(A), "A must be row-major"
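The row-major/column-major requirements asserted here and in the hunks below come down to tensor strides. A hypothetical stride check for intuition only (not necessarily how the repo's _is_column_major helper is implemented):

import torch

def looks_column_major(x: torch.Tensor) -> bool:
    # Column-major over the last two dims: elements of a column are adjacent in memory,
    # i.e. unit stride along dim -2 (size-1 edge cases ignored in this sketch).
    return x.stride(-2) == 1

a = torch.randn(4, 8)      # contiguous, row-major
b = torch.randn(8, 4).t()  # transposed view of shape (4, 8), column-major
print(looks_column_major(a), looks_column_major(b))  # False True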
@@ -154,12 +154,12 @@ def forward(

         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_data_row_major), (
-            "A must be row-major for output = A @ B"
-        )
-        assert _is_column_major(B_t_data_col_major), (
-            "B must be column-major for output = A @ B"
-        )
+        assert not _is_column_major(
+            A_data_row_major
+        ), "A must be row-major for output = A @ B"
+        assert _is_column_major(
+            B_t_data_col_major
+        ), "B must be column-major for output = A @ B"

         # Squeeze empty dims out of scales, to comply with grouped mm API.
         # A_scales shape: (M,1) or (B, M, 1)
@@ -209,12 +209,12 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_data_row_major), (
-            "grad_output must be row-major for grad_A = grad_output @ B"
-        )
-        assert _is_column_major(B_data_col_major), (
-            "B must be column-major for grad_A = grad_output @ B"
-        )
+        assert not _is_column_major(
+            grad_output_data_row_major
+        ), "grad_output must be row-major for grad_A = grad_output @ B"
+        assert _is_column_major(
+            B_data_col_major
+        ), "B must be column-major for grad_A = grad_output @ B"

         # Squeeze empty dims out of scales, to comply with grouped mm API.
         # grad_output_scales shape: (M,1) or (B, M, 1)
@@ -259,12 +259,12 @@ def backward(ctx, grad_output: torch.Tensor):

         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_data_row_major), (
-            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
-        )
-        assert _is_column_major(A_data_col_major), (
-            "A must be column-major for grad_B = grad_output_t @ A"
-        )
+        assert not _is_column_major(
+            grad_output_t_data_row_major
+        ), "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        assert _is_column_major(
+            A_data_col_major
+        ), "A must be column-major for grad_B = grad_output_t @ A"

         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we need
@@ -449,11 +449,11 @@ def backward(ctx, grad_out: torch.Tensor):

         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             grad_out_t_scales,
             scale_group_offsets,
         )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             A_t_scales,
             scale_group_offsets,
         )
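To make the offsets arithmetic above concrete: each scale column covers block_size elements along K, so dividing the per-group K offsets by the block size yields the group boundaries in the scale tensor, which both rearrange calls then consume (now via the CUDA kernel instead of the Triton one). The offset values below are illustrative only.

import torch

block_size = 32  # MX block size along K
offs = torch.tensor([64, 192, 320], dtype=torch.int32)  # illustrative per-group K offsets

scale_group_offsets = offs // block_size
print(scale_group_offsets)  # tensor([ 2,  6, 10], dtype=torch.int32)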
@@ -518,21 +518,21 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 ) -> torch.Tensor:
     assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
     assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
-    assert A_scale.shape[0] == A_data.shape[0], (
-        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
-    )
-    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
-    )
-    assert B_scale.shape[0] == B_data.shape[0], (
-        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
-    assert B_scale.shape[1] == B_data.shape[1], (
-        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
-    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
+    assert (
+        A_scale.shape[0] == A_data.shape[0]
+    ), f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
+    assert (
+        A_scale.shape[1] == A_data.shape[1] // block_size
+    ), f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
+    assert (
+        B_scale.shape[0] == B_data.shape[0]
+    ), f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
+    assert (
+        B_scale.shape[1] == B_data.shape[1]
+    ), f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
+    assert (
+        B_scale.shape[2] == B_data.shape[2] // block_size
+    ), f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"

     # Dequantize input
     # A_data shape: (M, K)
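Spelled out with illustrative sizes (all names and values here are chosen for the example only), the shape contract these asserts encode is:

import torch

M, K, E, N, block_size = 128, 256, 4, 64, 32  # illustrative sizes

A_data = torch.randn(M, K)                    # 2D activations
A_scale = torch.randn(M, K // block_size)     # one scale per 32-element block along K
B_data = torch.randn(E, N, K)                 # 3D expert weights
B_scale = torch.randn(E, N, K // block_size)  # one scale per 32-element block along K

assert A_scale.shape == (M, K // block_size)
assert B_scale.shape == (E, N, K // block_size)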

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 3 deletions
@@ -1219,10 +1219,9 @@ def _fake_mxfp8_quantize(
         (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device
     )

-    # colwise scales are written in column-major format to avoid uncoalesced global memory accesses
-    scales_colwise = torch.empty_strided(
+    # colwise scales now use a contiguous row-major layout; microbenchmarks show memory bandwidth is virtually identical to coalesced stores
+    scales_colwise = torch.empty(
         (cols, num_row_blocks),
-        (1, cols),
         dtype=torch.float8_e8m0fnu,
         device=x.device,
     )
