     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mx_block_rearrange_2d_K_groups_cuda,
     mxfp8_quantize_cuda_3d,
-    triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -92,28 +92,28 @@ def forward(
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
-        assert A.size(-1) % 16 == 0, (
-            f"A must have a last dim divisible by 16, but got shape: {A.shape}"
-        )
-        assert B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0, (
-            f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
-        )
+        assert (
+            A.size(-1) % 16 == 0
+        ), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
+        assert (
+            B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
+        ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
 
         # Assert input tensors are in high-precision dtypes.
-        assert A.dtype == torch.float32 or A.dtype == torch.bfloat16, (
-            "A must be float32 or bfloat16"
-        )
-        assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
-            "B must be float32 or bfloat16"
-        )
-        assert offs is None or offs.dtype == torch.int32, (
-            "offs must be int32 tensor or None"
-        )
+        assert (
+            A.dtype == torch.float32 or A.dtype == torch.bfloat16
+        ), "A must be float32 or bfloat16"
+        assert (
+            B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
+        ), "B must be float32 or bfloat16"
+        assert (
+            offs is None or offs.dtype == torch.int32
+        ), "offs must be int32 tensor or None"
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
-        assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
-        )
+        assert A.size(-1) == B_t.size(
+            -2
+        ), f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
 
         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
         assert not _is_column_major(A), "A must be row-major"
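
A note for readers on the row-/column-major asserts in this hunk: the checks are purely stride-based and independent of shape. Below is a minimal illustrative sketch, assuming a stride-based helper like the stand-in shown (the stand-in function and the toy shapes are mine, not the library's exact code):

import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Hypothetical stand-in for torchao's _is_column_major helper: the last two
    # dims are column-major when the second-to-last stride is 1.
    return x.stride(-2) == 1

A = torch.randn(64, 32, dtype=torch.bfloat16)  # row-major; last dim 32 is divisible by 16
B_t = torch.randn(4, 16, 32, dtype=torch.bfloat16).transpose(-2, -1)  # shape (4, 32, 16), column-major view

assert not is_column_major(A)      # left operand must be row-major
assert is_column_major(B_t)        # right operand is column-major after the transpose
assert A.size(-1) == B_t.size(-2)  # contraction dims line up: 32 == 32
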
@@ -154,12 +154,12 @@ def forward(
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_data_row_major), (
-            "A must be row-major for output = A @ B"
-        )
-        assert _is_column_major(B_t_data_col_major), (
-            "B must be column-major for output = A @ B"
-        )
+        assert not _is_column_major(
+            A_data_row_major
+        ), "A must be row-major for output = A @ B"
+        assert _is_column_major(
+            B_t_data_col_major
+        ), "B must be column-major for output = A @ B"
 
         # Squeeze empty dims out of scales, to comply with grouped mm API.
         # A_scales shape: (M,1) or (B, M, 1)
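
The "squeeze empty dims" step referenced above just drops the trailing singleton dimension that the rowwise quantization kernels leave on the scales. A tiny sketch with made-up shapes, purely illustrative:

import torch

# Rowwise scales arrive as (M, 1) for 2D inputs or (B, M, 1) for 3D inputs.
A_scales_2d = torch.rand(4096, 1)
A_scales_3d = torch.rand(8, 4096, 1)

# Dropping the trailing singleton dim yields the shapes the grouped mm API expects.
assert A_scales_2d.squeeze(-1).shape == (4096,)
assert A_scales_3d.squeeze(-1).shape == (8, 4096)
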
@@ -209,12 +209,12 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_data_row_major), (
-            "grad_output must be row-major for grad_A = grad_output @ B"
-        )
-        assert _is_column_major(B_data_col_major), (
-            "B must be column-major for grad_A = grad_output @ B"
-        )
+        assert not _is_column_major(
+            grad_output_data_row_major
+        ), "grad_output must be row-major for grad_A = grad_output @ B"
+        assert _is_column_major(
+            B_data_col_major
+        ), "B must be column-major for grad_A = grad_output @ B"
 
         # Squeeze empty dims out of scales, to comply with grouped mm API.
         # grad_output_scales shape: (M,1) or (B, M, 1)
@@ -259,12 +259,12 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_data_row_major), (
-            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
-        )
-        assert _is_column_major(A_data_col_major), (
-            "A must be column-major for grad_B = grad_output_t @ A"
-        )
+        assert not _is_column_major(
+            grad_output_t_data_row_major
+        ), "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        assert _is_column_major(
+            A_data_col_major
+        ), "A must be column-major for grad_B = grad_output_t @ A"
 
         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we need
@@ -449,11 +449,11 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
             grad_out_t_scales,
             scale_group_offsets,
         )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+        A_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
            A_t_scales,
            scale_group_offsets,
        )
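
In the hunk above only the kernel performing the K-group scale rearrangement changes (Triton kernel to CUDA kernel); the offsets arithmetic is untouched. A small sketch of that offsets relationship, assuming an MX block size of 32 and the two-argument call signature shown in the diff (the kernel call itself needs a CUDA device and the prototype kernels, so it is left as a comment):

import torch

block_size = 32  # MX scaling block size along K (assumed)

# Token-group end offsets along the contraction dim, e.g. derived from per-expert token counts.
offs = torch.tensor([1024, 2048, 4096], dtype=torch.int32)

# Each scale covers `block_size` contiguous elements along K, so group boundaries
# in the scale tensors are the token boundaries divided by the block size.
scale_group_offsets = offs // block_size  # tensor([32, 64, 128], dtype=torch.int32)

# The blocked rearrangement itself would then be invoked as in the diff, e.g.:
# grad_out_t_scales_blocked = mx_block_rearrange_2d_K_groups_cuda(
#     grad_out_t_scales, scale_group_offsets
# )
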
@@ -518,21 +518,21 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 ) -> torch.Tensor:
     assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
     assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
-    assert A_scale.shape[0] == A_data.shape[0], (
-        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
-    )
-    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
-    )
-    assert B_scale.shape[0] == B_data.shape[0], (
-        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
-    assert B_scale.shape[1] == B_data.shape[1], (
-        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
-    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
-    )
+    assert (
+        A_scale.shape[0] == A_data.shape[0]
+    ), f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
+    assert (
+        A_scale.shape[1] == A_data.shape[1] // block_size
+    ), f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
+    assert (
+        B_scale.shape[0] == B_data.shape[0]
+    ), f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
+    assert (
+        B_scale.shape[1] == B_data.shape[1]
+    ), f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
+    assert (
+        B_scale.shape[2] == B_data.shape[2] // block_size
+    ), f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
 
     # Dequantize input
     # A_data shape: (M, K)