NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093
NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093cael-ling wants to merge 2 commits into
Conversation
…obatches For block-scaled NVFP4 a cached weight participates in two GEMMs per step: fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale swizzle was recomputed lazily inside every GEMM and discarded, so with N microbatches the weight scale swizzle ran 2*N times per step even though the weight is quantized only once. Because weight RHT is disabled, the weight scales are not swizzled by the cast-fusion path; with optimize_for_gemm off they also skip the post-quantize fallback swizzle, so the only swizzle site left for the weight is the lazy one inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM. (Activation input/grad_output quantizers already set optimize_for_gemm=True, so they were pre-swizzled via cast-fusion/fallback; only the weight was missed.) Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the swizzle is done once at quantize time (via the post-quantize fallback), persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and GroupedLinear (per expert). Gated to the cached path (is_first_microbatch is not None) with fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there. No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure layout permutation, so numerics are unchanged. Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted. Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR eliminates redundant GEMM-swizzle operations for cached NVFP4 block-scaled weights by setting
Confidence Score: 4/5The four module changes are a clean, consistent gating of optimize_for_gemm on the cached non-FSDP path; the only gap is that LayerNormMLP's fc1/fc2 two-quantizer path is not covered by the new tests. The optimization logic is correct and consistent across all four modules — FSDP1 and FSDP2 exclusions are properly handled. The new test file exercises three of the four modified code paths and verifies both numerical parity and flag persistence. LayerNormMLP is modified in a structurally distinct way (two independent quantizers) but has no corresponding test, leaving a gap that could hide a future regression in that path. tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py — missing LayerNormMLP test case Important Files Changed
Sequence DiagramsequenceDiagram
participant T as Training Loop
participant M as TE Module
participant Q as WeightQuantizer
participant G as general_gemm
Note over T,G: Cached path (is_first_microbatch is not None, no FSDP)
T->>M: "forward(x, is_first_microbatch=True)"
M->>Q: "set optimize_for_gemm = True"
M->>Q: quantize(weight) → swizzle scales eagerly
Q-->>M: FP4Tensor (swizzled scales cached in workspace)
M->>G: fprop GEMM (cached swizzled rowwise scales)
G-->>M: output y
M->>G: dgrad GEMM (cached swizzled columnwise scales)
G-->>M: dx
T->>M: "forward(x2, is_first_microbatch=False)"
M->>Q: "set optimize_for_gemm = True"
Note over M,Q: Weight already cached — skip requantize
M->>G: fprop GEMM (reuse cached swizzled scales)
G-->>M: output y2
M->>G: dgrad GEMM (reuse cached swizzled scales)
G-->>M: dx2
Note over T,G: Uncached path (is_first_microbatch=None) or FSDP
T->>M: "forward(x, is_first_microbatch=None)"
M->>Q: "set optimize_for_gemm = False"
M->>G: fprop GEMM (lazy swizzle rowwise inside GEMM)
G-->>M: output y
M->>G: dgrad GEMM (lazy swizzle columnwise inside GEMM)
G-->>M: dx
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| with te.autocast(enabled=True, recipe=recipe): | ||
| out = module(x, is_first_microbatch=is_first) | ||
| out.sum().backward() | ||
| return out.detach().float(), x.grad.detach().float() | ||
|
|
||
|
|
There was a problem hiding this comment.
Missing
LayerNormMLP test coverage
layernorm_mlp.py is one of four files modified by this PR, yet the test suite parametrizes only over ["Linear", "LayerNormLinear"] for both test_weight_swizzle_cache_numerics and test_lazy_path_not_swizzled. The fc1/fc2 two-quantizer path in LayerNormMLP is structurally different from the single-quantizer modules: it independently gates fc1_weight_quantizer.optimize_for_gemm and fc2_weight_quantizer.optimize_for_gemm using separate cache_name_fc1/cache_name_fc2 variables. If either gating expression were wrong (e.g. swapping fc1/fc2 names), existing tests would not catch it.
Description
For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside
general_gemmon every micro-batch and thrown away, so withNmicro-batches the weight scale swizzle runs2*Ntimes per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already setoptimize_for_gemm=Trueand were pre-swizzled; only the weight was missed.)This PR sets
weight_quantizer.optimize_for_gemm=Trueon the cached, non-FSDP path so the swizzle is done once at quantize time, persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM —2*N→2swizzles per step.Applied to
Linear,LayerNormLinear,LayerNormMLP(fc1 + fc2) andGroupedLinear(per expert).Gated to the cached path (
is_first_microbatch is not None) withfsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there.No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).
Swizzling is a pure layout permutation, so numerics are unchanged.
New
tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py: asserts the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) forLinear/LayerNormLinear/GroupedLinear, and that_with_gemm_swizzled_scalesis set and persisted on the cached workspace.pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp"— no regressions.pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path"— no regressions.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: