From 0ee2ede273083a0b401eb35947b35cc9396aacfe Mon Sep 17 00:00:00 2001
From: Cael Ling
Date: Fri, 5 Jun 2026 07:16:24 -0700
Subject: [PATCH 1/2] [PyTorch] NVFP4: cache GEMM-swizzled weight scale factors
 across microbatches

For block-scaled NVFP4 a cached weight participates in two GEMMs per
step: fprop (rowwise scales) and dgrad (columnwise scales). The
GEMM-ready scale swizzle was recomputed lazily inside every GEMM and
discarded, so with N microbatches the weight scale swizzle ran 2*N times
per step even though the weight is quantized only once.

Because weight RHT is disabled, the weight scales are not swizzled by
the cast-fusion path; with optimize_for_gemm off they also skip the
post-quantize fallback swizzle, so the only swizzle site left for the
weight is the lazy one inside general_gemm (swizzle_scales_for_gemm),
which re-runs on every GEMM. (Activation input/grad_output quantizers
already set optimize_for_gemm=True, so they were pre-swizzled via
cast-fusion/fallback; only the weight was missed.)

Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path
so the swizzle is done once at quantize time (via the post-quantize
fallback), persisted on the cached workspace
(_with_gemm_swizzled_scales=True), and reused by every GEMM
(swizzle_scales_for_gemm early-returns) -> 2 swizzles per step instead
of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and
GroupedLinear (per expert).

Gated to the cached path (is_first_microbatch is not None) with
fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using
the un-swizzled scale layout, so pre-swizzling is unsupported there.
No-op for recipes whose scales do not require swizzling (e.g. per-tensor
FP8). Swizzling is a pure layout permutation, so numerics are unchanged.

Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the
cached eager-swizzle path matches the lazy-swizzle baseline (fprop +
dgrad) for Linear/LayerNormLinear/GroupedLinear and that the swizzled
flag is persisted.

Signed-off-by: Cael Ling
---
 .../nvfp4/test_nvfp4_weight_swizzle_cache.py | 194 ++++++++++++++++++
 .../pytorch/module/grouped_linear.py         |  12 ++
 .../pytorch/module/layernorm_linear.py       |  13 ++
 .../pytorch/module/layernorm_mlp.py          |  20 ++
 transformer_engine/pytorch/module/linear.py  |  16 ++
 5 files changed, 255 insertions(+)
 create mode 100644 tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py

diff --git a/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py b/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py
new file mode 100644
index 0000000000..3cc793e133
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Tests for the cached-weight scale-swizzle optimization.
+
+For block-scaled NVFP4 a weight participates in two GEMMs per step:
+
+  * fprop: ``y = x @ Wt``  -> consumes the weight's **rowwise** scale factors
+  * dgrad: ``dx = dY @ W`` -> consumes the weight's **columnwise** scale factors
+
+cuBLAS/CUTLASS needs those scale factors in a GEMM-"swizzled" layout. Without
+``optimize_for_gemm`` on the *weight* quantizer that swizzle is recomputed
+lazily inside every GEMM and discarded, so with ``N`` micro-batches the weight
+scale swizzle runs ``2*N`` times per step even though the weight is quantized
+once. When the quantized weight is cached across micro-batches
+(``is_first_microbatch`` is not ``None``) and FSDP is not in use, the module
+sets ``weight_quantizer.optimize_for_gemm = True`` so the swizzle is done once
+at quantize time, persisted on the cached workspace
+(``_with_gemm_swizzled_scales = True``), and reused by every GEMM -> ``2``
+swizzles per step instead of ``2*N``.
+
+These tests verify that:
+
+1. The optimization is **numerically a no-op**: swizzling is a pure layout
+   permutation of the scale factors, so the cached (eager-swizzle) path must
+   produce the same fprop output and dgrad as the un-cached (lazy-swizzle)
+   baseline, for every distinct micro-batch.
+2. The ``_with_gemm_swizzled_scales`` flag is actually set and persisted on the
+   cached weight workspace.
+"""
+
+import pytest
+import torch
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import NVFP4BlockScaling
+
+
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
+
+pytestmark = pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+
+
+def _make_module(kind, in_features, out_features, device):
+    if kind == "Linear":
+        return te.Linear(
+            in_features, out_features, bias=True, params_dtype=torch.bfloat16
+        ).to(device)
+    if kind == "LayerNormLinear":
+        return te.LayerNormLinear(
+            in_features, out_features, bias=True, params_dtype=torch.bfloat16
+        ).to(device)
+    raise ValueError(f"unknown module kind {kind}")
+
+
+def _clone_params(src, dst):
+    """Copy src's parameters into dst so both modules start identical."""
+    with torch.no_grad():
+        dst_params = dict(dst.named_parameters())
+        for name, param in src.named_parameters():
+            dst_params[name].copy_(param)
+
+
+def _step(module, x, is_first, recipe):
+    x = x.detach().clone().requires_grad_(True)
+    module.zero_grad(set_to_none=True)  # per-micro-batch grads (no accumulation)
+    with te.autocast(enabled=True, recipe=recipe):
+        out = module(x, is_first_microbatch=is_first)
+    out.sum().backward()
+    return out.detach().float(), x.grad.detach().float()
+
+
+@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
+@pytest.mark.parametrize("microbatches", [1, 4])
+@pytest.mark.parametrize(
+    "shape", [(1024, 1024), (2048, 512)], ids=["1024x1024", "2048x512"]
+)
+def test_weight_swizzle_cache_numerics(kind, microbatches, shape):
+    """Cached eager-swizzle path == lazy-swizzle baseline (fprop + dgrad)."""
+    torch.manual_seed(1234)
+    device = "cuda"
+    in_features, out_features = shape
+    batch = 512
+
+    # Stochastic rounding is the only run-to-run nondeterminism source (RHT uses
+    # a fixed sign mask) and it is applied to the bwd grad regardless of this
+    # optimization, so disable it to make eager-vs-lazy weight swizzle
+    # bit-comparable. The swizzle is a pure layout transform, so with SR off the
+    # two paths must match tightly.
+    recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
+
+    # ref: always lazy-swizzle (is_first_microbatch=None => no weight cache =>
+    # optimize_for_gemm stays False). opt: cached eager-swizzle path. Identical
+    # weights so per-micro-batch outputs are directly comparable.
+    ref = _make_module(kind, in_features, out_features, device)
+    opt = _make_module(kind, in_features, out_features, device)
+    _clone_params(ref, opt)
+
+    # Distinct inputs per micro-batch (mirrors gradient accumulation: different
+    # data each micro-batch, same weight).
+    inputs = [
+        torch.randn(batch, in_features, dtype=torch.bfloat16, device=device)
+        for _ in range(microbatches)
+    ]
+
+    atol, rtol = 1e-3, 1e-3
+    for mb in range(microbatches):
+        ref_out, ref_dgrad = _step(ref, inputs[mb], None, recipe)
+        opt_out, opt_dgrad = _step(opt, inputs[mb], mb == 0, recipe)
+        torch.testing.assert_close(
+            opt_out, ref_out, atol=atol, rtol=rtol, msg=f"fprop mismatch at mb {mb}"
+        )
+        torch.testing.assert_close(
+            opt_dgrad, ref_dgrad, atol=atol, rtol=rtol, msg=f"dgrad mismatch at mb {mb}"
+        )
+
+    # The swizzled flag must be set & persisted on the cached weight workspace.
+    workspaces = opt._fp8_workspaces
+    assert workspaces, "no cached weight workspace was created on the optimized module"
+    for name, ws in workspaces.items():
+        assert getattr(ws, "_with_gemm_swizzled_scales", False) is True, (
+            f"cached weight workspace {name!r} scales were not pre-swizzled "
+            "(optimize_for_gemm not applied)"
+        )
+
+
+@pytest.mark.parametrize("microbatches", [1, 4])
+@pytest.mark.parametrize("num_gemms", [1, 2])
+def test_grouped_linear_weight_swizzle_cache_numerics(microbatches, num_gemms):
+    """GroupedLinear (MoE expert path): cached eager-swizzle == lazy baseline."""
+    torch.manual_seed(1234)
+    device = "cuda"
+    in_features, out_features = 1024, 1024
+    tokens_per_gemm = 256
+    batch = tokens_per_gemm * num_gemms
+    m_splits = [tokens_per_gemm] * num_gemms
+
+    recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
+
+    ref = te.GroupedLinear(
+        num_gemms, in_features, out_features, bias=True, params_dtype=torch.bfloat16
+    ).to(device)
+    opt = te.GroupedLinear(
+        num_gemms, in_features, out_features, bias=True, params_dtype=torch.bfloat16
+    ).to(device)
+    _clone_params(ref, opt)
+
+    inputs = [
+        torch.randn(batch, in_features, dtype=torch.bfloat16, device=device)
+        for _ in range(microbatches)
+    ]
+
+    def grouped_step(module, x, is_first):
+        x = x.detach().clone().requires_grad_(True)
+        module.zero_grad(set_to_none=True)
+        with te.autocast(enabled=True, recipe=recipe):
+            out = module(x, m_splits, is_first_microbatch=is_first)
+        out.sum().backward()
+        return out.detach().float(), x.grad.detach().float()
+
+    atol, rtol = 1e-3, 1e-3
+    for mb in range(microbatches):
+        ref_out, ref_dgrad = grouped_step(ref, inputs[mb], None)
+        opt_out, opt_dgrad = grouped_step(opt, inputs[mb], mb == 0)
+        torch.testing.assert_close(
+            opt_out, ref_out, atol=atol, rtol=rtol, msg=f"fprop mismatch at mb {mb}"
+        )
+        torch.testing.assert_close(
+            opt_dgrad, ref_dgrad, atol=atol, rtol=rtol, msg=f"dgrad mismatch at mb {mb}"
+        )
+
+    workspaces = opt._fp8_workspaces
+    assert len(workspaces) == num_gemms, "expected one cached workspace per expert"
+    for name, ws in workspaces.items():
+        assert getattr(ws, "_with_gemm_swizzled_scales", False) is True, (
+            f"cached weight workspace {name!r} scales were not pre-swizzled"
+        )
+
+
+@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
+def test_lazy_path_not_swizzled(kind):
+    """Without weight caching (is_first_microbatch=None) no workspace is created
+    and the optimization stays off — guards against accidentally always-on."""
+    torch.manual_seed(0)
+    device = "cuda"
+    recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
+    module = _make_module(kind, 1024, 1024, device)
+    x = torch.randn(512, 1024, dtype=torch.bfloat16, device=device, requires_grad=True)
+    with te.autocast(enabled=True, recipe=recipe):
+        out = module(x, is_first_microbatch=None)
+    out.sum().backward()
+    assert not module._fp8_workspaces, (
+        "lazy path (is_first_microbatch=None) must not populate the weight cache"
+    )
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 15ec3fe322..d5c4aebe1e 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -1735,6 +1735,18 @@ def forward(
                 else [None] * num_gemms
             )
 
+            # Pre-swizzle (and cache) the weight scale factors when the quantized
+            # weights are cached across microbatches, so the per-GEMM scale swizzle
+            # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
+            # from 2*num_microbatches kernels to 2 per step per expert. Gated to the
+            # cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
+            # scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
+            # unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
+            if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
+                for weight_quantizer in weight_quantizers:
+                    if weight_quantizer is not None:
+                        weight_quantizer.optimize_for_gemm = True
+
             non_tensor_args = (
                 self.apply_bias,
                 is_first_microbatch,
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 7fc96d4779..ca3cafe0f1 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -1716,6 +1716,19 @@ def forward(
                 self._fp8_workspaces.get(cache_name) if cache_name is not None else None
             )
 
+            # Pre-swizzle (and cache) the weight scale factors when the quantized
+            # weight is cached across microbatches, so the per-GEMM scale swizzle
+            # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
+            # from 2*num_microbatches kernels to 2 per step. Gated to the cached,
+            # non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled scales;
+            # see NVFP4Tensor.fsdp_pre_all_gather). No-op for non-swizzled recipes.
+            if weight_quantizer is not None:
+                weight_quantizer.optimize_for_gemm = (
+                    cache_name is not None
+                    and self.fsdp_group is None
+                    and not getattr(self, "is_fsdp2", False)
+                )
+
             non_tensor_args = (
                 self.eps,
                 is_first_microbatch,
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 6c6cca74ef..cd31e1de7b 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -2358,6 +2358,26 @@ def forward(
                 self._fp8_workspaces.get(cache_name_fc2) if cache_name_fc2 is not None else None
             )
 
+            # Pre-swizzle (and cache) the weight scale factors when the quantized
+            # weight is cached across microbatches, so the per-GEMM scale swizzle
+            # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
+            # from 2*num_microbatches kernels to 2 per step per FC. Gated to the
+            # cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
+            # scales; see NVFP4Tensor.fsdp_pre_all_gather). No-op for non-swizzled
+            # recipes (e.g. per-tensor FP8).
+            if fc1_weight_quantizer is not None:
+                fc1_weight_quantizer.optimize_for_gemm = (
+                    cache_name_fc1 is not None
+                    and self.fsdp_group is None
+                    and not getattr(self, "is_fsdp2", False)
+                )
+            if fc2_weight_quantizer is not None:
+                fc2_weight_quantizer.optimize_for_gemm = (
+                    cache_name_fc2 is not None
+                    and self.fsdp_group is None
+                    and not getattr(self, "is_fsdp2", False)
+                )
+
             non_tensor_args = (
                 self.eps,
                 is_first_microbatch,
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 6c2d98d160..cbf396740c 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -1861,6 +1861,22 @@ def forward(
                 self._fp8_workspaces.get(cache_name) if cache_name is not None else None
             )
 
+            # When the quantized weight is cached and reused across microbatches,
+            # pre-swizzle its scale factors at quantize time and persist them on the
+            # cached workspace. This collapses the per-GEMM scale swizzle (fprop
+            # rowwise + dgrad columnwise, redone every microbatch) from
+            # 2*num_microbatches kernels down to 2 per optimizer step. Gated to the
+            # cached, non-FSDP path: FSDP/FSDP2 all-gather weights using the
+            # un-swizzled scale layout (see NVFP4Tensor.fsdp_pre_all_gather), so
+            # pre-swizzling is unsupported there. No-op for recipes whose scales do
+            # not require swizzling (e.g. per-tensor FP8).
+            if weight_quantizer is not None:
+                weight_quantizer.optimize_for_gemm = (
+                    cache_name is not None
+                    and self.fsdp_group is None
+                    and not getattr(self, "is_fsdp2", False)
+                )
+
             if self.fp8:
                 backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override
             else:

From f04d8006f4988af720a59431b9a84f62093893cc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 5 Jun 2026 14:31:57 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../nvfp4/test_nvfp4_weight_swizzle_cache.py  | 22 +++++++++----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py b/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py
index 3cc793e133..5eb60972da 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py
@@ -43,9 +43,9 @@
 
 def _make_module(kind, in_features, out_features, device):
     if kind == "Linear":
-        return te.Linear(
-            in_features, out_features, bias=True, params_dtype=torch.bfloat16
-        ).to(device)
+        return te.Linear(in_features, out_features, bias=True, params_dtype=torch.bfloat16).to(
+            device
+        )
     if kind == "LayerNormLinear":
         return te.LayerNormLinear(
             in_features, out_features, bias=True, params_dtype=torch.bfloat16
@@ -72,9 +72,7 @@ def _step(module, x, is_first, recipe):
 
 @pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
 @pytest.mark.parametrize("microbatches", [1, 4])
-@pytest.mark.parametrize(
-    "shape", [(1024, 1024), (2048, 512)], ids=["1024x1024", "2048x512"]
-)
+@pytest.mark.parametrize("shape", [(1024, 1024), (2048, 512)], ids=["1024x1024", "2048x512"])
 def test_weight_swizzle_cache_numerics(kind, microbatches, shape):
     """Cached eager-swizzle path == lazy-swizzle baseline (fprop + dgrad)."""
     torch.manual_seed(1234)
@@ -172,9 +170,9 @@ def grouped_step(module, x, is_first):
     workspaces = opt._fp8_workspaces
     assert len(workspaces) == num_gemms, "expected one cached workspace per expert"
     for name, ws in workspaces.items():
-        assert getattr(ws, "_with_gemm_swizzled_scales", False) is True, (
-            f"cached weight workspace {name!r} scales were not pre-swizzled"
-        )
+        assert (
+            getattr(ws, "_with_gemm_swizzled_scales", False) is True
+        ), f"cached weight workspace {name!r} scales were not pre-swizzled"
 
 
 @pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
@@ -189,6 +187,6 @@ def test_lazy_path_not_swizzled(kind):
     with te.autocast(enabled=True, recipe=recipe):
         out = module(x, is_first_microbatch=None)
     out.sum().backward()
-    assert not module._fp8_workspaces, (
-        "lazy path (is_first_microbatch=None) must not populate the weight cache"
-    )
+    assert (
+        not module._fp8_workspaces
+    ), "lazy path (is_first_microbatch=None) must not populate the weight cache"