Commit ac2f0a5

[mxfp8 moe training] update readme with kernel microbenchmarks for dsv3
stack-info: PR: #3521, branch: danielvegamyhre/stack/90
1 parent 79b248c commit ac2f0a5

torchao/prototype/moe_training/README.md

Lines changed: 47 additions & 0 deletions
@@ -222,6 +222,53 @@ cd benchmarks/prototype/moe_training/mxfp8
python roofline_unified.py --K=7168 --N=2048 --G=8 --power_limit_percent=100 --breakdown_M=131072 --plot_file=dsv3_rooflines.png
```

### MXFP8 Kernel Breakdown by Pass
The following table provides a detailed breakdown of all MXFP8 kernels used in the forward and backward passes, with shapes representative of **DeepSeekV3 671B** (dim=7168, hidden_dim=2048, total_tokens=131072, groups=8, block_size=32).

Benchmark results were measured on an **NVIDIA B200** GPU at an 80% power limit (peak achievable bandwidth: 5888 GB/s, peak MXFP8 throughput: 2808 TFLOPS).

**Environment:**
- torch: `2.11.0.dev20251216+cu128`
- torchao: `0.15.0+gitd1305bc78`
- GPU: NVIDIA B200

| Pass | Kernel | Purpose | Input Shape | Time (µs) | Efficiency |
|------|--------|---------|-------------|-----------|------------|
| **Forward** | `triton_to_mxfp8_dim0` | Quantize A (activations) along dim0 | (131072, 7168) | 580.6 | 83.3% peak BW |
| **Forward** | `mxfp8_quantize_cuda_3d` | Quantize B (weights) along dim0 | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
| **Forward** | `triton_mx_block_rearrange_2d_M_groups` | Convert A scales to blocked format | (131072, 224) | 198.7 | n/a |
| **Forward** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 2048, 224) | 11.4 | n/a |
| **Forward** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 7168) @ (8, 7168, 2048) | 1838.1 | 74.6% peak TFLOPS |
| **Backward (dA)** | `triton_to_mxfp8_dim0` | Quantize grad_out along dim0 | (131072, 2048) | 166.0 | 83.3% peak BW |
| **Backward (dA)** | `mxfp8_quantize_cuda_3d` | Quantize B along dim1 (N dimension) | (8, 2048, 7168) | 76.8 | 78.8% peak BW |
| **Backward (dA)** | `triton_mx_block_rearrange_2d_M_groups` | Convert grad_out scales to blocked format | (131072, 64) | 192.5 | n/a |
| **Backward (dA)** | `triton_mx_block_rearrange_per_group_3d` | Convert B scales to blocked format | (8, 7168, 64) | 11.0 | n/a |
| **Backward (dA)** | `torch._scaled_grouped_mm` | 2D-3D scaled grouped GEMM | (131072, 2048) @ (8, 2048, 7168) | 1838.1 | 74.6% peak TFLOPS |
| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize grad_out along dim1 (colwise) | (131072, 2048) | 191.7 | 72.1% peak BW |
| **Backward (dB)** | `mxfp8_quantize_cuda` | Quantize A along dim1 (colwise) | (131072, 7168) | 670.7 | 72.1% peak BW |
| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert grad_out_t scales to blocked format | (2048, 4096) | 17.4 | n/a |
| **Backward (dB)** | `mx_block_rearrange_2d_K_groups_cuda` | Convert A_t scales to blocked format | (7168, 4096) | 31.6 | n/a |
| **Backward (dB)** | `torch._scaled_grouped_mm` | 2D-2D scaled grouped GEMM | (2048, 131072) @ (131072, 7168) | 2412.4 | 56.9% peak TFLOPS |

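As a rough sanity check, the efficiency figures in the table can be reproduced from the shapes, runtimes, and peak numbers quoted above. The sketch below is illustrative only: it assumes the grouped GEMM is counted as `2 * total_tokens * dim * hidden_dim` FLOPs, and that quantization traffic is modeled as a bf16 read plus an fp8 write plus one e8m0 scale byte per 32 elements; the actual roofline script may use a different byte accounting.

```python
# Illustrative sanity check for the table above (not the benchmark code itself).
# Assumed cost model: GEMM FLOPs = 2 * M * K * N; quantization traffic =
# bf16 input read + fp8 output write + one e8m0 scale byte per 32 elements.

M, K, N = 131072, 7168, 2048          # total_tokens, dim, hidden_dim
PEAK_TFLOPS = 2808                    # peak MXFP8 throughput quoted above (TFLOPS)
PEAK_BW_GBPS = 5888                   # peak achievable bandwidth quoted above (GB/s)

# Forward 2D-3D grouped GEMM: (131072, 7168) @ (8, 7168, 2048) in 1838.1 us.
gemm_flops = 2 * M * K * N
gemm_tflops = gemm_flops / 1838.1e-6 / 1e12
print(f"GEMM: {gemm_tflops:.0f} TFLOPS -> {100 * gemm_tflops / PEAK_TFLOPS:.1f}% of peak")
# ~2094 TFLOPS -> ~74.6% of peak, matching the table.

# Forward activation quantization: triton_to_mxfp8_dim0 on (131072, 7168) in 580.6 us.
bytes_moved = M * K * 2 + M * K * 1 + M * (K // 32) * 1   # bf16 in, fp8 out, e8m0 scales
bw_gbps = bytes_moved / 580.6e-6 / 1e9
print(f"Quantize A: {bw_gbps:.0f} GB/s -> {100 * bw_gbps / PEAK_BW_GBPS:.1f}% of peak BW")
# ~4905 GB/s -> ~83.3% of peak BW, matching the table.
```
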
**Notes:**
- **Efficiency** is reported as a percentage of peak achievable memory bandwidth (for memory-bound quantization kernels) or of peak TFLOPS (for compute-bound GEMM kernels).
- Scale rearrangement kernels are neither conventionally memory-bandwidth-bound nor compute-bound, so only their absolute runtime is reported ("n/a" in the Efficiency column).
- Scale tensor shapes follow from the input shapes by dividing the scaling dimension by `block_size=32` (see the sketch after these notes).
- A detailed kernel breakdown with timing for all kernels is available in the roofline plots above (generated by `roofline_unified.py`).
- All kernels can be benchmarked individually using the scripts listed in the table below.

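To make the scale-shape rule concrete, here is a minimal sketch; the helper `mx_scale_shape` is hypothetical (not a torchao API) and simply divides the scaling dimension by `block_size`:

```python
# Hypothetical helper, for illustration only: one e8m0 scale per block_size
# contiguous elements along the scaling (reduction) dimension.
def mx_scale_shape(input_shape, scaling_dim, block_size=32):
    shape = list(input_shape)
    shape[scaling_dim] = shape[scaling_dim] // block_size
    return tuple(shape)

# Forward: A is (131072, 7168), scaled in blocks of 32 along K=7168.
assert mx_scale_shape((131072, 7168), scaling_dim=1) == (131072, 224)

# Backward (dA): grad_out is (131072, 2048), scaled along N=2048.
assert mx_scale_shape((131072, 2048), scaling_dim=1) == (131072, 64)

# Backward (dB): per the table, operands are transposed, e.g. A_t is
# (7168, 131072) scaled along the token dimension M=131072.
assert mx_scale_shape((7168, 131072), scaling_dim=1) == (7168, 4096)
```
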
**Benchmark Scripts:**

| Kernel Type | Benchmark Script |
|-------------|------------------|
| 2D Quantization (dim0/dim1) | `benchmarks/mx_formats/cast_bench.py` |
| 3D Quantization | `benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py` |
| 2D M-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py` |
| 2D K-groups Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py` |
| 3D Per-group Scale Rearrange | `benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py` |
| Grouped GEMM (2D-3D, 2D-2D) | `benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py` |
| Unified Roofline Analysis | `benchmarks/prototype/moe_training/mxfp8/roofline_unified.py` |

## Benchmark: single MoE layer forward + backward pass

| Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |
