From afc993a212dbcc04a8f6d0d54a9e0550f73fad50 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 4 Jun 2026 10:42:17 -0700 Subject: [PATCH 1/2] megacpp Signed-off-by: Zhongbo Zhu --- tests/pytorch/megacpp/test_grouped_mlp.py | 476 +++++++++++ transformer_engine/pytorch/csrc/extensions.h | 24 + .../pytorch/csrc/extensions/pybind.cpp | 19 + .../pytorch/csrc/megacpp/grouped_mlp.cpp | 797 ++++++++++++++++++ .../pytorch/ops/fused/__init__.py | 6 + .../ops/fused/backward_grouped_mlp_megacpp.py | 392 +++++++++ .../ops/fused/forward_grouped_mlp_megacpp.py | 382 +++++++++ 7 files changed, 2096 insertions(+) create mode 100644 tests/pytorch/megacpp/test_grouped_mlp.py create mode 100644 transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp create mode 100644 transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py create mode 100644 transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py new file mode 100644 index 0000000000..ddddcb7fc4 --- /dev/null +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -0,0 +1,476 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops + + +_HIDDEN_SIZE = 512 +_FFN_HIDDEN_SIZE = 256 + + +def _megacpp_available() -> tuple[bool, str]: + if not torch.cuda.is_available(): + return False, "CUDA is required" + if not te.is_bf16_available(): + return False, "BF16 is required" + if torch.cuda.get_device_capability() < (10, 0): + return False, "megacpp grouped MLP uses SM100 grouped GEMM" + if not te_ops.fused.ForwardGroupedMLP_MegaCpp.is_supported(): + return False, "ForwardGroupedMLP_MegaCpp is not supported" + if not te_ops.fused.BackwardGroupedMLP_MegaCpp.is_supported(): + return False, "BackwardGroupedMLP_MegaCpp is not supported" + return True, "" + + +_AVAILABLE, _SKIP_REASON = _megacpp_available() +pytestmark = pytest.mark.skipif(not _AVAILABLE, reason=_SKIP_REASON) + + +def _make_grouped_mlp( + *, + num_groups: int, + hidden_size: int, + ffn_hidden_size: int, + activation_kind: str, + bias: bool, + delay_wgrad_compute: bool, + accumulate_into_main_grad: bool, + glu_interleave_size: int | None, + single_grouped_param: bool, +) -> te_ops.Sequential: + gated_activation = activation_kind in ("scaled_swiglu", "scaled_clamped_qgeglu") + fc1_out_features = 2 * ffn_hidden_size if gated_activation else ffn_hidden_size + fc1 = te_ops.GroupedLinear( + num_groups, + hidden_size, + fc1_out_features, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + if activation_kind == "scaled_swiglu": + act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_clamped_qgeglu": + act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation_kind == "scaled_srelu": + act = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unsupported test activation_kind={activation_kind}.") + fc2 = te_ops.GroupedLinear( + num_groups, + ffn_hidden_size, + hidden_size, + bias=bias, + device="cuda", + dtype=torch.bfloat16, + delay_wgrad_compute=delay_wgrad_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_param, + single_grouped_bias=single_grouped_param and bias, + ) + return te_ops.Sequential(fc1, act, fc2) + + +def _copy_grouped_mlp_params(dst: te_ops.Sequential, src: te_ops.Sequential) -> None: + with torch.no_grad(): + for dst_linear, src_linear in ((dst[0], src[0]), (dst[2], src[2])): + if dst_linear.single_grouped_weight: + dst_linear.weight.rowwise_data.copy_(src_linear.weight.rowwise_data) + if dst_linear.has_bias: + dst_linear.bias.rowwise_data.copy_(src_linear.bias.rowwise_data) + else: + for group_idx in range(dst_linear.num_groups): + getattr(dst_linear, f"weight{group_idx}").copy_( + getattr(src_linear, f"weight{group_idx}") + ) + if dst_linear.has_bias: + getattr(dst_linear, f"bias{group_idx}").copy_( + getattr(src_linear, f"bias{group_idx}") + ) + + +def _init_main_grads(module: te_ops.Sequential) -> None: + for linear in (module[0], module[2]): + if linear.single_grouped_weight: + linear.weight.main_grad = torch.zeros( + linear.num_groups, + linear.out_features, + linear.in_features, + device="cuda", + dtype=torch.bfloat16, + ) + else: + for group_idx in range(linear.num_groups): + weight = getattr(linear, f"weight{group_idx}") + weight.main_grad = torch.zeros_like(weight) + + +def _run_grouped_mlp( + module: te_ops.Sequential, + x: torch.Tensor, + split_sizes: torch.Tensor, + act_scales: torch.Tensor, + dy: torch.Tensor, + *, + delay_wgrad_compute: bool, +) -> torch.Tensor: + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) + if delay_wgrad_compute: + module[0].backward_dw() + module[2].backward_dw() + return y + + +def _assert_grouped_mlp_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + *, + accumulate_into_main_grad: bool, +) -> None: + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + if test_linear.single_grouped_weight: + if accumulate_into_main_grad: + torch.testing.assert_close( + test_linear.weight.main_grad, + ref_linear.weight.main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + test_linear.weight.grad, + ref_linear.weight.grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + test_linear.bias.grad, + ref_linear.bias.grad, + rtol=2e-2, + atol=2e-2, + ) + continue + for group_idx in range(test_linear.num_groups): + if accumulate_into_main_grad: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").main_grad, + getattr(ref_linear, f"weight{group_idx}").main_grad, + rtol=2e-2, + atol=2e-2, + ) + else: + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_grouped_mlp_nonzero_expert_grads_close( + test: te_ops.Sequential, + ref: te_ops.Sequential, + split_sizes: list[int], +) -> None: + """Compare only non-empty experts; zero-token expert grads may be unwritten.""" + for test_linear, ref_linear in ((test[0], ref[0]), (test[2], ref[2])): + for group_idx, split_size in enumerate(split_sizes): + if split_size == 0: + continue + torch.testing.assert_close( + getattr(test_linear, f"weight{group_idx}").grad, + getattr(ref_linear, f"weight{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + if test_linear.has_bias: + torch.testing.assert_close( + getattr(test_linear, f"bias{group_idx}").grad, + getattr(ref_linear, f"bias{group_idx}").grad, + rtol=2e-2, + atol=2e-2, + ) + + +def _assert_valid_prefix_close( + test: torch.Tensor, + ref: torch.Tensor, + valid_tokens: int, +) -> None: + """Paged-stashed buffers only guarantee correctness in the valid token prefix.""" + if valid_tokens == 0: + return + torch.testing.assert_close(test[:valid_tokens], ref[:valid_tokens], rtol=2e-2, atol=2e-2) + + +def _make_split_tensor( + split_sizes: list[int], + *, + dtype: torch.dtype = torch.int64, + device: str = "cuda", +) -> torch.Tensor: + return torch.tensor(split_sizes, dtype=dtype, device=device) + + +def _run_megacpp_against_python( + *, + split_sizes_list: list[int], + physical_tokens: int, + split_dtype: torch.dtype, + split_device: str, + bias: bool = True, + glu_interleave_size: int | None = None, + activation_kind: str = "scaled_swiglu", + single_grouped_param: bool = False, + accumulate_into_main_grad: bool = False, + compare_zero_expert_grads: bool = True, + monkeypatch, +) -> None: + num_groups = len(split_sizes_list) + valid_tokens = sum(split_sizes_list) + assert physical_tokens >= valid_tokens + if single_grouped_param: + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + split_sizes = _make_split_tensor(split_sizes_list, dtype=split_dtype, device=split_device) + ref = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + test = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind=activation_kind, + bias=bias, + delay_wgrad_compute=False, + accumulate_into_main_grad=accumulate_into_main_grad, + glu_interleave_size=glu_interleave_size, + single_grouped_param=single_grouped_param, + ) + _copy_grouped_mlp_params(test, ref) + if accumulate_into_main_grad: + _init_main_grads(ref) + _init_main_grads(test) + + # Paged stashing passes a static physical buffer to the op while m_splits + # describe only the valid prefix. Rows after sum(m_splits) are garbage and + # must not affect outputs/gradients for the valid prefix. + x_ref = torch.randn( + physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + x_test = x_ref.detach().clone().requires_grad_() + act_scales_ref = torch.rand( + physical_tokens, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + act_scales_test = act_scales_ref.detach().clone().requires_grad_() + dy = torch.randn(physical_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "0") + y_ref = _run_grouped_mlp( + ref, + x_ref, + split_sizes, + act_scales_ref, + dy, + delay_wgrad_compute=False, + ) + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + y_test = _run_grouped_mlp( + test, + x_test, + split_sizes, + act_scales_test, + dy, + delay_wgrad_compute=False, + ) + + fuser = test._module_groups[0] + assert isinstance(fuser._forward_ops[0][0], te_ops.fused.ForwardGroupedMLP_MegaCpp) + assert isinstance(fuser._backward_ops[0][0], te_ops.fused.BackwardGroupedMLP_MegaCpp) + + _assert_valid_prefix_close(y_test, y_ref, valid_tokens) + _assert_valid_prefix_close(x_test.grad, x_ref.grad, valid_tokens) + _assert_valid_prefix_close( + act_scales_test.grad, + act_scales_ref.grad, + valid_tokens, + ) + if valid_tokens == physical_tokens and compare_zero_expert_grads: + _assert_grouped_mlp_close(test, ref, accumulate_into_main_grad=accumulate_into_main_grad) + elif valid_tokens > 0 and not single_grouped_param and not accumulate_into_main_grad: + _assert_grouped_mlp_nonzero_expert_grads_close(test, ref, split_sizes_list) + + +@pytest.mark.parametrize( + "single_grouped_param", + [False, True], + ids=["discrete_weight", "packed_weight"], +) +@pytest.mark.parametrize( + "accumulate_into_main_grad", + [False, True], + ids=["cpp_allocated_wgrad", "megatron_main_grad"], +) +def test_megacpp_grouped_mlp_wgrad_storage_matches_python( + single_grouped_param, + accumulate_into_main_grad, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + single_grouped_param=single_grouped_param, + accumulate_into_main_grad=accumulate_into_main_grad, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_dtype,split_device", + [ + pytest.param(torch.int64, "cuda", id="i64_cuda"), + pytest.param(torch.int32, "cuda", id="i32_cuda"), + pytest.param(torch.int64, "cpu", id="i64_cpu"), + ], +) +def test_megacpp_grouped_mlp_split_source_matches_python( + split_dtype, + split_device, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=split_dtype, + split_device=split_device, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "activation_kind", + ["scaled_swiglu", "scaled_srelu", "scaled_clamped_qgeglu"], + ids=["swiglu", "srelu", "clamped_qgeglu"], +) +@pytest.mark.parametrize( + "glu_interleave_size", + [None, 32], + ids=["no_interleave", "interleave_32"], +) +def test_megacpp_grouped_mlp_activation_matches_python( + activation_kind, + glu_interleave_size, + monkeypatch, +): + if activation_kind == "scaled_srelu" and glu_interleave_size is not None: + pytest.skip("ScaledSReLU is not a GLU activation.") + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + activation_kind=activation_kind, + glu_interleave_size=glu_interleave_size, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize("bias", [True, False], ids=["bias", "no_bias"]) +def test_megacpp_grouped_mlp_bias_matches_python(bias, monkeypatch): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=[256, 256, 512], + physical_tokens=1024, + split_dtype=torch.int64, + split_device="cuda", + bias=bias, + monkeypatch=monkeypatch, + ) + + +@pytest.mark.parametrize( + "split_sizes_list,physical_tokens", + [ + pytest.param([256, 256, 256, 256], 1024, id="even"), + pytest.param([0, 256, 256, 512], 1024, id="zero_front"), + pytest.param([256, 0, 256, 512], 1024, id="zero_middle"), + pytest.param([256, 256, 512, 0], 1024, id="zero_end"), + pytest.param([256, 256], 1024, id="paged_stashing_even_with_garbage"), + pytest.param([0, 256, 256], 1024, id="paged_stashing_zero_front_with_garbage"), + pytest.param([256, 0, 256], 1024, id="paged_stashing_zero_middle_with_garbage"), + pytest.param([256, 256, 0], 1024, id="paged_stashing_zero_end_with_garbage"), + pytest.param([0, 0, 0, 0], 1024, id="paged_stashing_zero_tokens_all_nonempty_input"), + ], +) +def test_megacpp_grouped_mlp_split_edge_cases( + split_sizes_list, + physical_tokens, + monkeypatch, +): + torch.manual_seed(1234) + _run_megacpp_against_python( + split_sizes_list=split_sizes_list, + physical_tokens=physical_tokens, + split_dtype=torch.int64, + split_device="cuda", + compare_zero_expert_grads=False, + monkeypatch=monkeypatch, + ) + + +def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch): + torch.manual_seed(1234) + num_groups = 3 + split_sizes = torch.tensor([256, 256, 512], dtype=torch.int64, device="cuda") + total_tokens = int(split_sizes.sum().item()) + module = _make_grouped_mlp( + num_groups=num_groups, + hidden_size=_HIDDEN_SIZE, + ffn_hidden_size=_FFN_HIDDEN_SIZE, + activation_kind="scaled_swiglu", + bias=True, + delay_wgrad_compute=True, + accumulate_into_main_grad=False, + glu_interleave_size=None, + single_grouped_param=False, + ) + x = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16).requires_grad_() + act_scales = torch.rand( + total_tokens, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() + dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) + + monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") + with pytest.raises(ValueError, match="delay_wgrad_compute"): + y = module(x, split_sizes, act_scales, split_sizes) + y.backward(dy) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..a59e85456d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -185,6 +185,30 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p at::Tensor workspace_cublas, bool use_split_accumulator, int math_sm_count); +/*************************************************************************************************** + * Mega C++ grouped MLP + **************************************************************************************************/ + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset); + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, + const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..78c9e280f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -357,6 +357,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("te_general_grouped_gemm_for_discrete_out", &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); + m.def("megacpp_grouped_mlp_forward", + &transformer_engine::pytorch::megacpp_grouped_mlp_forward, + "Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"), + py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"), + py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), + py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, + py::arg("activation_glu_linear_offset") = 0.0); + m.def("megacpp_grouped_mlp_backward", + &transformer_engine::pytorch::megacpp_grouped_mlp_backward, + "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"), + py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"), + py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), + py::arg("fc1_activation_input"), py::arg("fc2_x"), py::arg("act_scales"), + py::arg("fc1_weight"), py::arg("fc2_weight"), py::arg("fc1_wgrad_output"), + py::arg("fc1_compute_wgrad"), py::arg("fc1_accumulate_wgrad"), py::arg("fc2_wgrad_output"), + py::arg("fc2_compute_wgrad"), py::arg("fc2_accumulate_wgrad"), py::arg("activation"), + py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, + py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0, + py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp new file mode 100644 index 0000000000..2f9a642041 --- /dev/null +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -0,0 +1,797 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "../extensions.h" +#include "../pybind.h" +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "transformer_engine/activation.h" +#include "transformer_engine/gemm.h" +#include "transformer_engine/transformer_engine.h" + +namespace py = pybind11; + +namespace transformer_engine::pytorch { +namespace { + +constexpr int64_t kGroupedGemmCublasWorkspaceSize = 32 * 1024 * 1024 + 1024; + +bool is_none(py::handle obj) { return obj.is_none(); } + +std::vector tensor_shape_1d(const at::Tensor &tensor) { + return {static_cast(tensor.numel())}; +} + +at::Tensor maybe_cast_dtype(const at::Tensor &tensor, at::ScalarType dtype) { + at::Tensor out = tensor; + if (out.scalar_type() != dtype) { + out = out.to(out.options().dtype(dtype)); + } + return out; +} + +void check_contiguous(const at::Tensor &tensor, const std::string &name) { + NVTE_CHECK(tensor.is_contiguous(), name, " must be contiguous."); +} + +size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, + const c10::Device &device) { + NVTE_CHECK(split_sizes.dim() == 1, "split_sizes must be a 1D tensor."); + NVTE_CHECK(split_sizes.device() == device, "split_sizes must be on the current CUDA device."); + NVTE_CHECK(split_sizes.scalar_type() == at::kLong, + "split_sizes must be the int64 CUDA tensor returned by splits_to_offsets_multi."); + return static_cast(split_sizes.numel()); +} + +GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes, + const at::Tensor &tensor_offsets, int64_t logical_last_dim) { + const auto num_groups = static_cast(prepared_split_sizes.numel()); + const auto total_tokens = static_cast(data.numel() / logical_last_dim); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{total_tokens, static_cast(logical_last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + grouped.set_first_dims(prepared_split_sizes.data_ptr(), DType::kInt64, + std::vector{num_groups}); + grouped.set_tensor_offsets(tensor_offsets.data_ptr(), DType::kInt64, + std::vector{num_groups + 1}); + return grouped; +} + +GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups, + int64_t first_dim, int64_t last_dim) { + auto grouped = GroupedTensorWrapper( + num_groups, + std::vector{num_groups * static_cast(first_dim), + static_cast(last_dim)}); + grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), + tensor_shape_1d(data)); + return grouped; +} + +struct GroupedWeightArg { + bool is_grouped = false; + at::Tensor packed; + std::vector discrete; + // Logical per-expert weight shape. For both supported layouts: + // - packed single grouped weight: packed has shape [G, rows, cols] + // - discrete weights: each tensor has shape [rows, cols] + // rows = out_features, cols = in_features. + int64_t rows = 0; + int64_t cols = 0; + + c10::Device device() const { + return is_grouped ? packed.device() : discrete[0].device(); + } +}; + +GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype, + const std::string &name) { + GroupedWeightArg out; + if (py::isinstance(arg) || py::isinstance(arg)) { + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + out.discrete.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = maybe_cast_dtype(seq[i].cast(), dtype); + check_contiguous(tensor, name); + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2."); + if (i == 0) { + // Discrete case: each expert owns one [out_features, in_features] + // tensor. Cache the shared logical shape for later GEMM setup. + out.rows = tensor.size(0); + out.cols = tensor.size(1); + } else { + NVTE_CHECK(tensor.size(0) == out.rows && tensor.size(1) == out.cols, name, + " tensors must have a uniform shape."); + } + out.discrete.emplace_back(tensor); + } + return out; + } + + out.packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(out.packed.dim() == 3, name, " must be a tensor with shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + check_contiguous(out.packed, name); + out.is_grouped = true; + // Packed case: a single [G, out_features, in_features] tensor stores all + // experts, so dimensions 1 and 2 are the same per-expert logical shape. + out.rows = out.packed.size(1); + out.cols = out.packed.size(2); + return out; +} + +at::Tensor packed_bias_from_arg(py::handle arg, size_t num_groups, at::ScalarType dtype, + int64_t out_features, const std::string &name) { + if (is_none(arg)) { + return at::Tensor(); + } + + auto packed = maybe_cast_dtype(arg.cast(), dtype); + NVTE_CHECK(packed.dim() == 2, name, " must be a tensor with shape [num_groups, features]."); + NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", + num_groups, "."); + NVTE_CHECK(packed.size(1) == out_features, name, " second dimension must be ", out_features, "."); + check_contiguous(packed, name); + return packed; +} + +std::vector nvte_tensor_list_from_tensors(const std::vector &tensors, + std::vector *wrappers) { + wrappers->clear(); + wrappers->reserve(tensors.size()); + std::vector out; + out.reserve(tensors.size()); + for (const auto &tensor : tensors) { + wrappers->emplace_back(makeTransformerEngineTensor(tensor)); + out.emplace_back(wrappers->back().data()); + } + return out; +} + +int grouped_gemm_math_sm_count(const c10::Device &device) { + const int device_id = static_cast(device.index()); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + return sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); +} + +struct GroupedGemmResources { + c10::Device device; + size_t num_groups; + at::Tensor alpha; + at::Tensor beta_zero; + at::Tensor beta_one; + at::Tensor setup; + at::Tensor cublas; + TensorWrapper te_alpha; + TensorWrapper te_beta_zero; + TensorWrapper te_beta_one; + TensorWrapper te_setup; + TensorWrapper te_cublas; + std::optional config; + + GroupedGemmResources(const c10::Device &device_, size_t num_groups_) + : device(device_), + num_groups(num_groups_), + alpha(at::ones({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), + beta_zero( + at::zeros({static_cast(num_groups_)}, at::device(device).dtype(at::kFloat))), + beta_one(alpha), + setup(at::empty( + {static_cast(nvte_get_grouped_gemm_setup_workspace_size(num_groups_))}, + at::device(device).dtype(at::kByte))), + cublas(at::empty({kGroupedGemmCublasWorkspaceSize}, at::device(device).dtype(at::kByte))), + te_alpha(makeTransformerEngineTensor(alpha)), + te_beta_zero(makeTransformerEngineTensor(beta_zero)), + te_beta_one(makeTransformerEngineTensor(beta_one)), + te_setup(makeTransformerEngineTensor(setup.data_ptr(), + std::vector{static_cast(setup.numel())}, + DType::kByte)), + te_cublas(makeTransformerEngineTensor( + cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, + DType::kByte)) { + // These scratch tensors are intentionally local to one megacpp call. They + // are safe after this CPU function returns because every current cuBLAS + // grouped GEMM below is enqueued on at::cuda::getCurrentCUDAStream(), so + // PyTorch's caching allocator observes same-stream allocation/release + // ordering. If a future backend uses auxiliary streams, this helper must + // either record those streams on the tensors or extend resource lifetime. + const int math_sm_count = grouped_gemm_math_sm_count(device); + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + } + + NVTETensor beta(bool accumulate) { + return accumulate ? te_beta_one.data() : te_beta_zero.data(); + } + + NVTEGroupedMatmulConfig config_data() { + return config.has_value() ? static_cast(*config) : nullptr; + } +}; + +GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device, + size_t num_groups) { + // Keep the backend resource policy private to megacpp. Today this is cuBLAS + // grouped GEMM scratch; future backends can change this helper without + // changing the Python or pybind contract. + return GroupedGemmResources(device, num_groups); +} + +void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, bool transb, + GroupedTensorWrapper *D, GroupedGemmResources *resources, bool accumulate) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(), + resources->te_alpha.data(), resources->beta(accumulate), + resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), + at::cuda::getCurrentCUDAStream()); + }); +} + +std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, + at::ScalarType dtype, + const std::string &name) { + std::vector out; + if (is_none(arg)) { + return out; + } + out.reserve(num_groups); + if (py::isinstance(arg) || py::isinstance(arg)) { + auto seq = py::reinterpret_borrow(arg); + NVTE_CHECK(static_cast(seq.size()) == num_groups, name, " must have ", num_groups, + " tensors."); + for (size_t i = 0; i < num_groups; ++i) { + auto tensor = seq[i].cast(); + NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors."); + NVTE_CHECK(tensor.scalar_type() == dtype, name, " tensors must have the requested dtype."); + NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers."); + check_contiguous(tensor, name); + out.emplace_back(tensor); + } + return out; + } + + auto packed = arg.cast(); + NVTE_CHECK(packed.is_cuda(), name, " must be a CUDA tensor."); + NVTE_CHECK(packed.scalar_type() == dtype, name, " must have the requested dtype."); + NVTE_CHECK(packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(packed.size(0)) == num_groups, name, " first dimension must be ", + num_groups, "."); + check_contiguous(packed, name); + for (size_t i = 0; i < num_groups; ++i) { + out.emplace_back(packed.select(0, static_cast(i))); + } + return out; +} + +struct WgradOutput { + std::vector tensors; + at::Tensor packed; + bool is_grouped = false; + bool owns_storage = false; +}; + +WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num_groups, + at::ScalarType dtype, const c10::Device &device, int64_t rows, + int64_t cols, const std::string &name, + bool prefer_grouped_output) { + WgradOutput out; + if (!compute_wgrad) { + return out; + } + if (is_none(arg)) { + // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the + // allocation. Single grouped weight keeps this packed as [G, N, K]; + // discrete weights split the same packed allocation into per-expert views. + out.packed = at::empty({static_cast(num_groups), rows, cols}, + at::device(device).dtype(dtype)); + out.owns_storage = true; + out.is_grouped = prefer_grouped_output; + if (out.is_grouped) { + return out; + } + out.tensors.reserve(num_groups); + for (size_t i = 0; i < num_groups; ++i) { + out.tensors.emplace_back(out.packed.select(0, static_cast(i))); + } + return out; + } + if (!py::isinstance(arg) && !py::isinstance(arg)) { + // Case 3: single grouped weight with externally-owned storage, e.g. + // Megatron main_grad viewed as [G, N, K]. GEMM writes in-place and Python + // should not receive a newly allocated grad tensor from this helper. + out.packed = arg.cast(); + NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor."); + NVTE_CHECK(out.packed.scalar_type() == dtype, name, " must have the requested dtype."); + NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols]."); + NVTE_CHECK(static_cast(out.packed.size(0)) == num_groups, name, + " first dimension must be ", num_groups, "."); + NVTE_CHECK(out.packed.size(1) == rows && out.packed.size(2) == cols, name, + " has an unexpected shape."); + check_contiguous(out.packed, name); + out.is_grouped = true; + return out; + } + // Case 4: discrete weights with externally-owned per-expert buffers, e.g. + // Megatron main_grad list. GEMM writes each tensor in-place and returns no + // allocated grad list to Python. + out.tensors = output_tensor_list_from_arg(arg, num_groups, dtype, name); + return out; +} + +void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, + GroupedTensorWrapper *input, bool trans_input, + GroupedTensorWrapper *output, GroupedGemmResources *resources) { + if (weights->is_grouped) { + // Single grouped weight case: weights are packed as [G, N, K]. Wrap the + // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM. + auto grouped_weight = + make_uniform_grouped_tensor(weights->packed, input->num_tensors(), weights->rows, + weights->cols); + grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false); + } else { + // Discrete weight case: weights are a list of per-expert tensors. Use the + // discrete-input grouped GEMM variant. + std::vector weight_wrappers; + auto weight_nvte = nvte_tensor_list_from_tensors(weights->discrete, &weight_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_inputA( + weight_nvte.data(), weights->discrete.size(), trans_weight, input->data(), trans_input, + output->data(), output->data(), resources->te_alpha.data(), resources->beta(false), + resources->te_setup.data(), resources->te_cublas.data(), resources->config_data(), + at::cuda::getCurrentCUDAStream()); + }); + } +} + +std::vector grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy, + py::handle output, bool compute_wgrad, bool accumulate, + GroupedGemmResources *resources, at::ScalarType dtype, + int64_t rows, int64_t cols, const std::string &name, + bool prefer_grouped_output) { + auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype, + resources->device, rows, cols, name, prefer_grouped_output); + NVTE_CHECK(!(prepared.owns_storage && accumulate), name, + " cannot accumulate into a newly allocated wgrad buffer."); + std::vector returned_wgrads; + + if (prepared.is_grouped) { + // Cases 1 and 3: single grouped weight layout. + // Case 1: C++ allocated packed [G, N, K] storage; return [packed]. + // Case 3: caller provided packed storage, e.g. main_grad; write in-place + // and return nothing because autograd receives dummy wgrad tensors. + auto grouped_output = + make_uniform_grouped_tensor(prepared.packed, resources->num_groups, rows, cols); + grouped_gemm(x, false, dy, true, &grouped_output, resources, accumulate); + if (prepared.owns_storage) { + returned_wgrads.emplace_back(prepared.packed); + } + } else if (!prepared.tensors.empty()) { + // Cases 2 and 4: discrete per-expert weight layout. + // Case 2: C++ allocated packed backing storage and split it into views; + // return those views in parameter order. + // Case 4: caller provided per-expert buffers, e.g. main_grad list; write + // in-place and return nothing because autograd receives dummy wgrads. + std::vector output_wrappers; + auto output_nvte = nvte_tensor_list_from_tensors(prepared.tensors, &output_wrappers); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm_with_discrete_out( + x->data(), false, dy->data(), true, output_nvte.data(), resources->num_groups, + output_nvte.data(), resources->num_groups, resources->te_alpha.data(), + resources->beta(accumulate), resources->te_setup.data(), resources->te_cublas.data(), + resources->config_data(), at::cuda::getCurrentCUDAStream()); + }); + if (prepared.owns_storage) { + returned_wgrads = prepared.tensors; + } + } + return returned_wgrads; +} + +GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups, + at::ScalarType dtype, int64_t out_features) { + NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); + auto grouped = GroupedTensorWrapper( + num_groups, std::vector{num_groups, static_cast(out_features)}); + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), tensor_shape_1d(bias)); + return grouped; +} + +void add_grouped_bias(GroupedTensorWrapper *output, const at::Tensor &bias, size_t num_groups, + at::ScalarType dtype, int64_t out_features, + std::optional bias_scale = std::nullopt) { + if (!bias.defined()) { + return; + } + auto grouped_bias = make_grouped_bias(bias, num_groups, dtype, out_features); + if (bias_scale.has_value()) { + auto scale = maybe_cast_dtype(*bias_scale, at::kFloat); + check_contiguous(scale, "bias_scale"); + scale = scale.view({-1}); + auto te_scale = makeTransformerEngineTensor(scale); + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_scaled_bias_add(output->data(), grouped_bias.data(), te_scale.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else { + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_bias_add(output->data(), grouped_bias.data(), at::cuda::getCurrentCUDAStream()); + }); + } +} + +bool is_gated_activation(const std::string &activation) { + return activation == "swiglu" || activation == "clamped_swiglu" || activation == "geglu" || + activation == "reglu" || activation == "qgeglu" || activation == "sreglu"; +} + +at::Tensor maybe_deinterleave_glu(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + NVTE_CHECK(last_dim % (2 * glu_interleave_size) == 0, + "GLU interleaving requires the last dimension to be divisible by 2*interleave."); + check_contiguous(input, "GLU input"); + // Explicit layout materialization: GLU interleave changes memory order. + return input.view({-1, last_dim / (2 * glu_interleave_size), 2, glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor maybe_reinterleave_glu_grad(const at::Tensor &input, int64_t glu_interleave_size) { + if (glu_interleave_size <= 0) { + return input; + } + auto shape = input.sizes().vec(); + const int64_t last_dim = shape.back(); + check_contiguous(input, "GLU grad input"); + // Explicit layout materialization: reverse GLU interleave changes memory order. + return input.view({-1, 2, last_dim / (2 * glu_interleave_size), glu_interleave_size}) + .transpose(1, 2) + .contiguous() + .view(shape); +} + +at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &activation, + double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + const int64_t out_features = + is_gated_activation(activation) ? input.size(-1) / 2 : input.size(-1); + auto output = at::empty({input.size(0), out_features}, input.options()); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_swiglu(te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_glu(te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_geglu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_qgeglu(te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_reglu(te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_sreglu(te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_srelu(te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_gelu(te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_qgelu(te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_relu(te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_silu(te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation: ", activation); + } + }); + return output; +} + +at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, + const std::string &activation, double activation_limit, + double activation_alpha, + double activation_glu_linear_offset) { + auto output = at::empty_like(input); + auto te_grad = makeTransformerEngineTensor(grad); + auto te_input = makeTransformerEngineTensor(input); + auto te_output = makeTransformerEngineTensor(output); + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + if (activation == "swiglu") { + nvte_dswiglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "glu") { + nvte_dglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "geglu") { + nvte_dgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgeglu") { + nvte_dqgeglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "reglu") { + nvte_dreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "sreglu") { + nvte_dsreglu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "clamped_swiglu") { + nvte_clamped_dswiglu_v2(te_grad.data(), te_input.data(), te_output.data(), + static_cast(activation_limit), + static_cast(activation_alpha), + static_cast(activation_glu_linear_offset), stream); + } else if (activation == "srelu") { + nvte_dsrelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "gelu") { + nvte_dgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "qgelu") { + nvte_dqgelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "relu") { + nvte_drelu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else if (activation == "silu") { + nvte_dsilu(te_grad.data(), te_input.data(), te_output.data(), stream); + } else { + NVTE_ERROR("Unsupported megacpp grouped MLP activation backward: ", activation); + } + }); + return output; +} + +at::Tensor grouped_mlp_activation_forward( + const at::Tensor &input, const std::optional &act_scales, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, + activation_alpha, activation_glu_linear_offset); + if (!act_scales.has_value()) { + return activation_output; + } + auto act_scales_for_fc2 = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_fc2, "act_scales"); + return activation_output * act_scales_for_fc2.view({-1, 1}); +} + +struct ActivationBackwardResult { + at::Tensor grad_input; + at::Tensor grad_act_scales; +}; + +ActivationBackwardResult grouped_mlp_activation_backward( + const at::Tensor &grad_output, const at::Tensor &input, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset, at::ScalarType dtype, bool act_scales_requires_grad) { + auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); + + at::Tensor grad_activation_output = grad_output; + at::Tensor grad_act_scales; + if (act_scales.has_value()) { + if (act_scales_requires_grad) { + // Scaled activations compute y = activation(x) * act_scales[:, None]. + // Recompute activation(x) for dact_scales to match the Python basic-op + // path without saving another [tokens, hidden] activation tensor. + auto activation_output = + activation_forward_impl(activation_input, activation, activation_limit, activation_alpha, + activation_glu_linear_offset); + grad_act_scales = (activation_output * grad_output).sum(-1); + } + auto act_scales_for_grad = maybe_cast_dtype(*act_scales, dtype); + check_contiguous(act_scales_for_grad, "act_scales"); + grad_activation_output = grad_output * act_scales_for_grad.view({-1, 1}); + } + + auto grad_activation_input = + activation_backward_impl(grad_activation_output, activation_input, activation, activation_limit, + activation_alpha, activation_glu_linear_offset); + return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), + grad_act_scales}; +} + +} // namespace + +std::vector megacpp_grouped_mlp_forward( + const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight, + py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias, + const std::optional &act_scales, const std::string &activation, + int64_t glu_interleave_size, double activation_limit, double activation_alpha, + double activation_glu_linear_offset) { + NVTE_CHECK(input.is_cuda(), "megacpp_grouped_mlp_forward requires CUDA input."); + at::cuda::CUDAGuard device_guard(input.device()); + + const auto num_groups = static_cast(split_sizes.numel()); + NVTE_CHECK(num_groups > 0, "megacpp grouped MLP requires at least one group."); + + const auto dtype = input.scalar_type(); + NVTE_CHECK(dtype == at::kBFloat16 || dtype == at::kHalf, + "megacpp grouped MLP currently supports BF16/FP16 only."); + + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + const int64_t activation_out_features = + is_gated_activation(activation) ? fc1_out_features / 2 : fc1_out_features; + NVTE_CHECK(activation_out_features == fc2_in_features, + "FC1 activation output dimension must match FC2 input dimension."); + auto fc1_bias_tensor = + packed_bias_from_arg(fc1_bias, num_groups, dtype, fc1_out_features, "fc1_bias"); + auto fc2_bias_tensor = + packed_bias_from_arg(fc2_bias, num_groups, dtype, fc2_out_features, "fc2_bias"); + + auto x = maybe_cast_dtype(input, dtype); + check_contiguous(x, "input"); + x = x.view({-1, in_features}); + auto [split_sizes_i64, split_offsets] = splits_to_offsets_multi( + split_sizes, x.device(), + std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, + std::vector{true, true, true, true, true}, + std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, + true); + // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and + // offsets in the same order as the stride list above. The CuTe path also asks + // for int32 split_points, but cuBLAS grouped GEMM does not consume them. + NVTE_CHECK(split_offsets.size() == 5, "Expected five grouped split-offset tensors."); + auto base_offsets = split_offsets[0]; + auto x_offsets = split_offsets[1]; + auto fc1_offsets = split_offsets[2]; + auto fc2_offsets = split_offsets[3]; + auto output_offsets = split_offsets[4]; + const int64_t total_tokens = x.size(0); + auto gemm_resources = make_grouped_mlp_backend_resources(x.device(), num_groups); + + auto fc1_preact = at::empty({total_tokens, fc1_out_features}, x.options()); + auto grouped_x = make_grouped_tensor(x.view({-1}), split_sizes_i64, x_offsets, in_features); + auto grouped_fc1_preact = + make_grouped_tensor(fc1_preact.view({-1}), split_sizes_i64, fc1_offsets, fc1_out_features); + grouped_gemm_fwd_dgrad(&fc1_weights, true, &grouped_x, false, &grouped_fc1_preact, + &gemm_resources); + add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); + + auto fc2_x = + grouped_mlp_activation_forward(fc1_preact, act_scales, activation, glu_interleave_size, + activation_limit, activation_alpha, + activation_glu_linear_offset, dtype); + + std::vector out_shape = input.sizes().vec(); + out_shape.back() = fc2_out_features; + auto output = at::empty(out_shape, x.options()); + auto output_2d = output.view({-1, fc2_out_features}); + auto grouped_fc2_x = + make_grouped_tensor(fc2_x.view({-1}), split_sizes_i64, fc2_offsets, fc2_in_features); + auto grouped_output = + make_grouped_tensor(output_2d.view({-1}), split_sizes_i64, output_offsets, fc2_out_features); + grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output, + &gemm_resources); + add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); + + return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets, + output_offsets, fc1_preact, fc2_x}; +} + +py::tuple megacpp_grouped_mlp_backward( + const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, + const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, + const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, + const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, + py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, + py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, + const std::string &activation, int64_t glu_interleave_size, double activation_limit, + double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, + bool input_requires_grad) { + (void)base_offsets; + NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); + at::cuda::CUDAGuard device_guard(grad_output.device()); + + const auto num_groups = num_groups_from_prepared_split_sizes(split_sizes, grad_output.device()); + const auto dtype = grad_output.scalar_type(); + auto fc1_weights = weight_arg_from_py(fc1_weight, num_groups, dtype, "fc1_weight"); + auto fc2_weights = weight_arg_from_py(fc2_weight, num_groups, dtype, "fc2_weight"); + + const int64_t in_features = fc1_weights.cols; + const int64_t fc1_out_features = fc1_weights.rows; + const int64_t fc2_out_features = fc2_weights.rows; + const int64_t fc2_in_features = fc2_weights.cols; + + auto dy = maybe_cast_dtype(grad_output, dtype); + check_contiguous(dy, "grad_output"); + dy = dy.view({-1, fc2_out_features}); + const int64_t total_tokens = dy.size(0); + auto gemm_resources = make_grouped_mlp_backend_resources(grad_output.device(), num_groups); + + auto grouped_dy = + make_grouped_tensor(dy.view({-1}), split_sizes, fc2_dy_offsets, fc2_out_features); + std::vector fc2_wgrads; + if (fc2_compute_wgrad) { + auto fc2_x_for_wgrad = maybe_cast_dtype(fc2_x, dtype); + check_contiguous(fc2_x_for_wgrad, "fc2_x"); + fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features}); + auto grouped_fc2_x_for_wgrad = + make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + fc2_wgrads = + grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_weights.is_grouped); + } + + auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); + auto grouped_fc2_dx = + make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features); + grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, + &gemm_resources); + + auto activation_grads = grouped_mlp_activation_backward( + fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, + activation_limit, activation_alpha, activation_glu_linear_offset, dtype, + act_scales_requires_grad); + auto fc1_dy = activation_grads.grad_input; + auto grad_act_scales = activation_grads.grad_act_scales; + auto grouped_fc1_dy = + make_grouped_tensor(fc1_dy.view({-1}), split_sizes, fc1_offsets, fc1_out_features); + + std::vector fc1_wgrads; + if (fc1_compute_wgrad) { + auto x_for_wgrad = maybe_cast_dtype(x, dtype); + check_contiguous(x_for_wgrad, "x"); + x_for_wgrad = x_for_wgrad.view({-1, in_features}); + auto grouped_x_for_wgrad = + make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features); + fc1_wgrads = + grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", + fc1_weights.is_grouped); + } + + at::Tensor grad_input; + if (input_requires_grad) { + std::vector grad_input_shape = grad_output.sizes().vec(); + grad_input_shape.back() = in_features; + grad_input = at::empty(grad_input_shape, dy.options()); + auto grad_input_2d = grad_input.view({-1, in_features}); + auto grouped_grad_input = make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, + x_offsets, in_features); + grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, + &gemm_resources); + } else { + grad_input = at::empty({0}, dy.options()); + } + + auto empty_return = at::empty({0}, dy.options()); + if (!grad_act_scales.defined()) { + grad_act_scales = empty_return; + } + return py::make_tuple(grad_input, fc1_dy, grad_act_scales, fc1_wgrads, fc2_wgrads); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 78f9d880ba..fd09162ade 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -39,3 +39,9 @@ BackwardGroupedMLP_CuTeGEMMDGLU, BackwardGroupedMLP_CuTeGEMMDUnary, ) +from .forward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_MegaCpp, +) +from .backward_grouped_mlp_megacpp import ( # pylint: disable=wrong-import-position + BackwardGroupedMLP_MegaCpp, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..ebaf30d075 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -0,0 +1,392 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP backward fuser.""" + +from __future__ import annotations +import functools +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...utils import clear_tensor_data, get_device_compute_capability +from ...triton.grouped_dbias_dscales import compute_grouped_dbias +from ..basic import GroupedLinear +from ..fuser import register_backward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + view_main_grad_as_grouped_buffer, +) +from .forward_grouped_mlp_megacpp import ( + _megacpp_activation_config, + _megacpp_enabled, + _megacpp_supports_recipe, + _resolve_megacpp_grouped_mlp_config, +) + + +def _megacpp_saved_weight_arg( + saved_tensors: tuple[torch.Tensor, ...], + *, + single_weight_arg: bool, + num_groups: int, +) -> tuple[torch.Tensor | list[torch.Tensor], tuple[torch.Tensor, ...]]: + """Unpack saved C++ weight argument in the same shape used by forward.""" + if single_weight_arg: + return saved_tensors[0], saved_tensors[1:] + return list(saved_tensors[:num_groups]), saved_tensors[num_groups:] + + +def _delay_wgrad(fc_op: GroupedLinear, ctx: OperationContext) -> bool: + """Whether this FC op requested unsupported delayed wgrad.""" + return bool( + ctx.weight_requires_grad + and fc_op.wgrad_store is not None + and fc_op.wgrad_store.delay_wgrad_compute() + ) + + +def _compute_bias_grad_params( + fc_op: GroupedLinear, + dy_2d: torch.Tensor, + base_offsets: torch.Tensor, + *, + num_groups: int, + dtype: torch.dtype, +) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]: + """Compute bias grads in GroupedLinear parameter layout.""" + if not fc_op.has_bias: + return None, None + dbias_packed = compute_grouped_dbias(dy_2d, base_offsets, num_groups).to(dtype=dtype) + if fc_op.single_grouped_bias: + return None, dbias_packed + return [dbias_packed[idx] for idx in range(num_groups)], None + + +def _prepare_cpp_wgrad_output( + fc_op: GroupedLinear, + ctx: OperationContext, + *, + num_groups: int, + weight_shape: tuple[int, int], + label: str, +) -> tuple[Optional[torch.Tensor | list[torch.Tensor]], bool, bool, list[Optional[torch.Tensor]]]: + """Return an optional externally-owned wgrad buffer for C++. + + If Megatron has already installed ``main_grad`` buffers, C++ writes into + them. Otherwise this returns ``None`` and C++ allocates/returns a packed + ``[num_groups, out_features, in_features]`` wgrad tensor. + """ + weights = fc_op._get_weight_tensors() + weight_grads: list[Optional[torch.Tensor]] = ( + [None] if fc_op.single_grouped_weight else [None] * num_groups + ) + if _delay_wgrad(fc_op, ctx): + raise ValueError("megacpp grouped MLP does not support delay_wgrad_compute=True.") + if not ctx.weight_requires_grad: + return None, False, False, weight_grads + + accumulate_into_main_grad = False + if fc_op.single_grouped_weight: + if fc_op._accumulate_into_main_grad: + main_grad = get_main_grad_from_param(weights[0], op_label=label) + wgrad_output = view_main_grad_as_grouped_buffer( + main_grad, + num_groups, + weight_shape, + label=f"{label} weight", + ) + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + else: + if fc_op._accumulate_into_main_grad: + wgrad_output = [get_main_grad_from_param(w, op_label=label) for w in weights] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + weight_grads = get_dummy_wgrads_for_params(weights) + else: + wgrad_output = None + + return wgrad_output, True, accumulate_into_main_grad, weight_grads + + +def _assemble_grad_params( + fc_op: GroupedLinear, + weight_grads: list[Optional[torch.Tensor]], + bias_grads: Optional[list[torch.Tensor]], + bias_grad_packed: Optional[torch.Tensor], + *, + num_groups: int, +) -> list[Optional[torch.Tensor]]: + """Assemble parameter grads in GroupedLinear registration order.""" + if not fc_op.has_bias: + return weight_grads + if fc_op.single_grouped_bias: + return weight_grads + [bias_grad_packed] + bias_list = bias_grads if bias_grads is not None else [None] * num_groups + if fc_op.single_grouped_weight: + return bias_list + weight_grads + return weight_grads + bias_list + + +class BackwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental C++ grouped MLP backward for BF16/FP16. + + Weight gradients are computed in C++. Delayed wgrad is intentionally not + supported in this first implementation to keep ownership and lifetime rules + simple. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_backward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + dtype = fc1_ctx.dtype + + fc1_saved = fc1_ctx.saved_tensors + split_sizes, base_offsets, x_offsets, fc1_offsets = fc1_saved[:4] + x, fc1_activation_input = fc1_saved[4:6] + fc1_weight_arg, _ = _megacpp_saved_weight_arg( + fc1_saved[6:], + single_weight_arg=bool(getattr(fc1_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + activation_config = _megacpp_activation_config(activation_op) + _, act_scales = activation_ctx.saved_tensors + + fc2_saved = fc2_ctx.saved_tensors + fc2_offsets = fc2_saved[2] + fc2_dy_offsets = fc2_saved[3] + fc2_x = fc2_saved[4] + fc2_weight_arg, _ = _megacpp_saved_weight_arg( + fc2_saved[5:], + single_weight_arg=bool(getattr(fc2_ctx, "single_weight_arg", False)), + num_groups=num_groups, + ) + + ( + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc1_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc1_op, + fc1_ctx, + num_groups=num_groups, + weight_shape=(fc1_op.out_features, fc1_op.in_features), + label="Grouped MLP megacpp backward (FC1)", + ) + ( + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + fc2_weight_grads, + ) = _prepare_cpp_wgrad_output( + fc2_op, + fc2_ctx, + num_groups=num_groups, + weight_shape=(fc2_op.out_features, fc2_op.in_features), + label="Grouped MLP megacpp backward (FC2)", + ) + ( + grad_input, + fc1_dy, + grad_act_scales, + fc1_owned_weight_grads, + fc2_owned_weight_grads, + ) = tex.megacpp_grouped_mlp_backward( + grad_output.to(dtype=dtype), + split_sizes, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + base_offsets, + x, + fc1_activation_input, + fc2_x, + act_scales, + fc1_weight_arg, + fc2_weight_arg, + fc1_wgrad_output, + fc1_compute_wgrad, + fc1_accumulate_wgrad, + fc2_wgrad_output, + fc2_compute_wgrad, + fc2_accumulate_wgrad, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + bool(activation_ctx.extra_input_requires_grad), + bool(fc1_ctx.input_requires_grad), + ) + if not fc1_ctx.input_requires_grad: + grad_input = None + + grad_output_2d = grad_output.reshape(-1, fc2_op.out_features).to(dtype=dtype) + fc2_bias_grads, fc2_bias_grad_packed = _compute_bias_grad_params( + fc2_op, + grad_output_2d, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + fc1_bias_grads, fc1_bias_grad_packed = _compute_bias_grad_params( + fc1_op, + fc1_dy, + base_offsets, + num_groups=num_groups, + dtype=dtype, + ) + + # Wgrad ownership cases: + # 1. No weight grad: keep [None] placeholders prepared above. + # 2. Megatron-owned main_grad: C++ wrote into the provided buffer; + # keep dummy wgrads prepared above for autograd. + # 3. C++-owned allocation: replace the placeholder list with returned + # wgrads. Single grouped weight returns [packed], discrete weights + # return one tensor per expert. + if fc2_ctx.weight_requires_grad and not fc2_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc2_op.single_grouped_weight else num_groups + if len(fc2_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC2 expected {expected_wgrads} owned wgrad tensors.") + fc2_weight_grads = fc2_owned_weight_grads + fc2_grad_params = _assemble_grad_params( + fc2_op, + fc2_weight_grads, + fc2_bias_grads, + fc2_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(fc2_x) + + # Same ownership policy as FC2. Megatron-owned main_grad keeps the + # prepared dummy grads; C++-owned allocation uses the returned wgrads. + if fc1_ctx.weight_requires_grad and not fc1_op._accumulate_into_main_grad: + expected_wgrads = 1 if fc1_op.single_grouped_weight else num_groups + if len(fc1_owned_weight_grads) != expected_wgrads: + raise RuntimeError(f"FC1 expected {expected_wgrads} owned wgrad tensors.") + fc1_weight_grads = fc1_owned_weight_grads + fc1_grad_params = _assemble_grad_params( + fc1_op, + fc1_weight_grads, + fc1_bias_grads, + fc1_bias_grad_packed, + num_groups=num_groups, + ) + clear_tensor_data(x) + + activation_grad_extra = ( + (grad_act_scales.to(dtype=dtype),) + if activation_ctx.extra_input_requires_grad + else (None,) + ) + + return grad_input, [fc1_grad_params, (), fc2_grad_params], [ + (None,), + activation_grad_extra, + (None,), + ] + + +def fuse_backward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP backward fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not BackwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif ( + window[0]._scale_bias + or window[2]._scale_bias + ): + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + BackwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Use the same opt-in and recipe gate as forward. Unsupported recipes fall +# through unchanged so the matching recipe-specific backward fuser can run. +register_backward_fusion(fuse_backward_megacpp_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py new file mode 100644 index 0000000000..bd05b6218f --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -0,0 +1,382 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mega C++ grouped MLP forward fuser.""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import os +from typing import Any, NamedTuple, Optional + +import torch + +import transformer_engine_torch as tex +from ...quantization import Recipe +from ...tensor import Quantizer +from ...utils import get_device_compute_capability +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext + + +def _megacpp_enabled() -> bool: + """Whether the experimental grouped MLP C++ path is explicitly enabled.""" + return int(os.getenv("NVTE_MEGACPP_GROUPED_LINEAR", "0")) > 0 + + +def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool: + """Whether megacpp is a valid candidate for the active quantization recipe. + + Today the C++ implementation is BF16/FP16-only, so only the no-recipe path + is supported. Returning False for FP8 recipes is intentional: it leaves the + op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match. + Future MXFP8/NVFP4 support should be enabled by changing this predicate, + not by reordering fusion registrations. + """ + if recipe is None: + return True + if recipe.mxfp8() or recipe.nvfp4(): + return False + return False + + +class _MegaCppActivationConfig(NamedTuple): + """Activation semantics consumed by the C++ grouped MLP path.""" + + name: str + is_scaled: bool + is_gated: bool + glu_interleave_size: int + limit: float = 0.0 + alpha: float = 0.0 + glu_linear_offset: float = 0.0 + + +def _megacpp_activation_config(activation) -> _MegaCppActivationConfig: + """Return activation parameters consumed by the C++ grouped MLP path.""" + glu_interleave_size = int(getattr(activation, "glu_interleave_size", None) or 0) + if isinstance(activation, ScaledSwiGLU): + return _MegaCppActivationConfig("swiglu", True, True, glu_interleave_size) + if isinstance(activation, ScaledClampedQGeGLU): + return _MegaCppActivationConfig( + "clamped_swiglu", + True, + True, + glu_interleave_size, + float(activation._clamped.limit), + float(activation._clamped.alpha), + float(activation._clamped.glu_linear_offset), + ) + if isinstance(activation, ScaledSReLU): + return _MegaCppActivationConfig("srelu", True, False, 0) + if getattr(activation, "num_extra_inputs", 0) == 0: + return _MegaCppActivationConfig("plain_unsupported", False, False, 0) + raise TypeError( + "megacpp grouped MLP currently supports only ScaledSwiGLU, " + "ScaledClampedQGeGLU, and ScaledSReLU." + ) + + +def _resolve_megacpp_grouped_mlp_config( + fc1: GroupedLinear, + activation, + fc2: GroupedLinear, +) -> _MegaCppActivationConfig: + """Resolve megacpp activation config and validate grouped MLP support.""" + config = _megacpp_activation_config(activation) + if not config.is_scaled: + raise RuntimeError( + "megacpp grouped MLP keeps an optional-scale activation API, but plain " + f"{activation.__class__.__name__} is not supported yet." + ) + if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + expected_fc1_out_features = 2 * fc2.in_features if config.is_gated else fc2.in_features + if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if config.glu_interleave_size and fc1.out_features % (2 * config.glu_interleave_size) != 0: + raise ValueError( + "GLU interleaving requires FC1 out_features to be divisible by " + f"2*glu_interleave_size, got out_features={fc1.out_features}, " + f"glu_interleave_size={config.glu_interleave_size}." + ) + return config + + +def _megacpp_weight_arg( + linear_op: GroupedLinear, + dtype: torch.dtype, + *, + input_requires_grad: bool, +) -> torch.Tensor | list[torch.Tensor]: + """Return GEMM-ready high-precision weights for the current C++ path. + + Keep the layout policy in GroupedLinear. This handles quantized weights the + same way as the Python grouped GEMM path: BF16/FP16 compute dequantizes when + needed, while a future quantized-compute path can preserve quantized weights + by switching ``with_quantized_compute``. + """ + with_quantized_compute = False + if linear_op.single_grouped_weight: + grouped_weight = linear_op._get_grouped_weight_for_gemm( + linear_op.weight, + [linear_op.get_quantizer("forward", 1)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + if grouped_weight.rowwise_data is None: + raise RuntimeError("megacpp grouped MLP expected dense grouped weight rowwise_data.") + # Keep single grouped weight packed. The C++ path wraps this as a + # uniform GroupedTensor and dispatches nvte_grouped_gemm instead of + # expanding it into per-expert discrete tensors. + return grouped_weight.rowwise_data.view( + linear_op.num_groups, + linear_op.out_features, + linear_op.in_features, + ) + return linear_op._get_discrete_weights_for_gemm( + [getattr(linear_op, f"weight{idx}") for idx in range(linear_op.num_groups)], + [linear_op.get_quantizer("forward", 2 * idx + 1) for idx in range(linear_op.num_groups)], + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + + +def _megacpp_bias_arg(linear_op: GroupedLinear, dtype: torch.dtype) -> Optional[torch.Tensor]: + """Return a packed [G, N] high-precision bias tensor or None.""" + grouped_bias = linear_op._get_grouped_bias_for_gemm(dtype) + if grouped_bias is None: + return None + return grouped_bias.rowwise_data.view(linear_op.num_groups, linear_op.out_features) + + +class ForwardGroupedMLP_MegaCpp(FusedOperation): + """Experimental BF16/FP16 grouped MLP forward implemented in C++. + + The C++ function returns plain tensors only. Python still owns autograd + context layout; delayed wgrad is rejected by the matching backward op. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether the C++ grouped MLP path can be dispatched.""" + if not torch.cuda.is_available(): + return False + if get_device_compute_capability()[0] < 10: + return False + return hasattr(tex, "megacpp_grouped_mlp_forward") + + def __init__( + self, + *, + fc1: GroupedLinear, + activation: Optional[FusibleOperation], + fc2: GroupedLinear, + ) -> None: + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) + _resolve_megacpp_grouped_mlp_config(fc1, activation, fc2) + if fc1._scale_bias or fc2._scale_bias: + raise RuntimeError("megacpp grouped MLP does not support scale_bias yet.") + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + del prev_op_grad_output_quantizer, next_op_input_quantizer, basic_op_kwargs + fc1_op, activation_op, fc2_op = self.basic_ops + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs + num_groups = fc1_op.num_groups + + split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + split_sizes.size() != fc2_split_sizes.size() + or split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError(f"{self.__class__.__name__} got different split sizes for FC1/FC2.") + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, got {int(split_sizes.numel())}.") + + activation_config = _megacpp_activation_config(activation_op) + act_scales = basic_op_extra_inputs[1][0] + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + dtype = ( + torch.get_autocast_dtype("cuda") + if torch.is_autocast_enabled() + else fc1_weight_param.dtype + ) + if dtype not in (torch.bfloat16, torch.float16): + raise RuntimeError(f"megacpp grouped MLP supports BF16/FP16 only, got {dtype}.") + + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad + fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad + + fc1_weights = _megacpp_weight_arg( + fc1_op, + dtype, + input_requires_grad=input_requires_grad, + ) + fc2_weights = _megacpp_weight_arg( + fc2_op, + dtype, + input_requires_grad=input_requires_grad, + ) + ( + fc2_out, + x, + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + fc2_offsets, + fc2_dy_offsets, + fc1_activation_input, + fc2_x, + ) = tex.megacpp_grouped_mlp_forward( + input_.to(dtype=dtype), + split_sizes, + fc1_weights, + _megacpp_bias_arg(fc1_op, dtype), + fc2_weights, + _megacpp_bias_arg(fc2_op, dtype), + act_scales, + activation_config.name, + activation_config.glu_interleave_size, + activation_config.limit, + activation_config.alpha, + activation_config.glu_linear_offset, + ) + + if x.data_ptr() == input_.data_ptr(): + x._do_not_clear = True + + if requires_grad: + fc1_saved_weights = [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights + fc2_saved_weights = [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + + fc1_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + x_offsets, + fc1_offsets, + x, + fc1_activation_input, + *fc1_saved_weights, + ) + fc1_ctx.use_megacpp_grouped_mlp = True + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = fc1_weight_requires_grad + fc1_ctx.single_weight_arg = isinstance(fc1_weights, torch.Tensor) + + activation_ctx.save_for_backward(fc1_activation_input, act_scales) + activation_ctx.extra_input_requires_grad = act_scales.requires_grad + activation_ctx.input_requires_grad = True + activation_ctx.dtype = dtype + + fc2_ctx.save_for_backward( + split_sizes_i64, + base_split_offsets, + fc2_offsets, + fc2_dy_offsets, + fc2_x, + *fc2_saved_weights, + ) + fc2_ctx.use_megacpp_grouped_mlp = True + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = fc2_weight_requires_grad + fc2_ctx.single_weight_arg = isinstance(fc2_weights, torch.Tensor) + + return fc2_out, [(), (), ()] + + +def fuse_forward_megacpp_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply opt-in C++ grouped MLP fusion for BF16/FP16.""" + if not _megacpp_enabled(): + return ops + if not _megacpp_supports_recipe(recipe): + return ops + if not ForwardGroupedMLP_MegaCpp.is_supported(): + return ops + + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif ( + window[0]._scale_bias + or window[2]._scale_bias + ): + matches_pattern = False + else: + try: + _resolve_megacpp_grouped_mlp_config(window[0], window[1], window[2]) + except (TypeError, ValueError, RuntimeError): + matches_pattern = False + + if matches_pattern: + window = [ + ForwardGroupedMLP_MegaCpp( + fc1=window[0], + activation=window[1], + fc2=window[2], + ) + ] + else: + out.extend(window[:-2]) + window = window[-2:] + + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + out.extend(window) + return out + + +# Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally +# return the ops unchanged so lower-priority recipe-specific fusers remain the +# fallback path. +register_forward_fusion(fuse_forward_megacpp_ops, prepend=True) From 1120ec9e4182f4d33e5b9423ff144029f5ce200b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Jun 2026 07:39:18 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/megacpp/test_grouped_mlp.py | 6 +- transformer_engine/pytorch/csrc/extensions.h | 18 ++- .../pytorch/csrc/extensions/pybind.cpp | 6 +- .../pytorch/csrc/megacpp/grouped_mlp.cpp | 122 ++++++++---------- .../ops/fused/backward_grouped_mlp_megacpp.py | 24 ++-- .../ops/fused/forward_grouped_mlp_megacpp.py | 18 ++- 6 files changed, 85 insertions(+), 109 deletions(-) diff --git a/tests/pytorch/megacpp/test_grouped_mlp.py b/tests/pytorch/megacpp/test_grouped_mlp.py index ddddcb7fc4..d3cc9cd04c 100644 --- a/tests/pytorch/megacpp/test_grouped_mlp.py +++ b/tests/pytorch/megacpp/test_grouped_mlp.py @@ -464,10 +464,10 @@ def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch): glu_interleave_size=None, single_grouped_param=False, ) - x = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16).requires_grad_() - act_scales = torch.rand( - total_tokens, device="cuda", dtype=torch.bfloat16 + x = torch.randn( + total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16 ).requires_grad_() + act_scales = torch.rand(total_tokens, device="cuda", dtype=torch.bfloat16).requires_grad_() dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16) monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1") diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a59e85456d..1fd334ea69 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -198,16 +198,14 @@ std::vector megacpp_grouped_mlp_forward( py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, - const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, - const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, - const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, - py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, - py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, - const std::string &activation, int64_t glu_interleave_size, double activation_limit, - double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, - bool input_requires_grad); + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, + const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, + const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, + bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, + bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, + double activation_limit, double activation_alpha, double activation_glu_linear_offset, + bool act_scales_requires_grad, bool input_requires_grad); /*************************************************************************************************** * Transpose diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 78c9e280f3..d70b6cb813 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -357,15 +357,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("te_general_grouped_gemm_for_discrete_out", &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out, "Grouped GEMM for discrete output list"); - m.def("megacpp_grouped_mlp_forward", - &transformer_engine::pytorch::megacpp_grouped_mlp_forward, + m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward, "Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0); - m.def("megacpp_grouped_mlp_backward", - &transformer_engine::pytorch::megacpp_grouped_mlp_backward, + m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward, "Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"), py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"), py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"), diff --git a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp index 2f9a642041..f85837f40a 100644 --- a/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp +++ b/transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include #include #include @@ -12,9 +14,6 @@ #include #include -#include -#include - #include "../extensions.h" #include "../pybind.h" #include "common/util/cuda_runtime.h" @@ -58,7 +57,8 @@ size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes, } GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes, - const at::Tensor &tensor_offsets, int64_t logical_last_dim) { + const at::Tensor &tensor_offsets, + int64_t logical_last_dim) { const auto num_groups = static_cast(prepared_split_sizes.numel()); const auto total_tokens = static_cast(data.numel() / logical_last_dim); auto grouped = GroupedTensorWrapper( @@ -75,9 +75,8 @@ GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prep GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups, int64_t first_dim, int64_t last_dim) { auto grouped = GroupedTensorWrapper( - num_groups, - std::vector{num_groups * static_cast(first_dim), - static_cast(last_dim)}); + num_groups, std::vector{num_groups * static_cast(first_dim), + static_cast(last_dim)}); grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()), tensor_shape_1d(data)); return grouped; @@ -94,9 +93,7 @@ struct GroupedWeightArg { int64_t rows = 0; int64_t cols = 0; - c10::Device device() const { - return is_grouped ? packed.device() : discrete[0].device(); - } + c10::Device device() const { return is_grouped ? packed.device() : discrete[0].device(); } }; GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype, @@ -201,9 +198,9 @@ struct GroupedGemmResources { te_alpha(makeTransformerEngineTensor(alpha)), te_beta_zero(makeTransformerEngineTensor(beta_zero)), te_beta_one(makeTransformerEngineTensor(beta_one)), - te_setup(makeTransformerEngineTensor(setup.data_ptr(), - std::vector{static_cast(setup.numel())}, - DType::kByte)), + te_setup(makeTransformerEngineTensor( + setup.data_ptr(), std::vector{static_cast(setup.numel())}, + DType::kByte)), te_cublas(makeTransformerEngineTensor( cublas.data_ptr(), std::vector{static_cast(cublas.numel())}, DType::kByte)) { @@ -220,9 +217,7 @@ struct GroupedGemmResources { } } - NVTETensor beta(bool accumulate) { - return accumulate ? te_beta_one.data() : te_beta_zero.data(); - } + NVTETensor beta(bool accumulate) { return accumulate ? te_beta_one.data() : te_beta_zero.data(); } NVTEGroupedMatmulConfig config_data() { return config.has_value() ? static_cast(*config) : nullptr; @@ -243,14 +238,12 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B, nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(), resources->te_alpha.data(), resources->beta(accumulate), resources->te_setup.data(), resources->te_cublas.data(), - resources->config_data(), - at::cuda::getCurrentCUDAStream()); + resources->config_data(), at::cuda::getCurrentCUDAStream()); }); } std::vector output_tensor_list_from_arg(py::handle arg, size_t num_groups, - at::ScalarType dtype, - const std::string &name) { + at::ScalarType dtype, const std::string &name) { std::vector out; if (is_none(arg)) { return out; @@ -303,8 +296,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the // allocation. Single grouped weight keeps this packed as [G, N, K]; // discrete weights split the same packed allocation into per-expert views. - out.packed = at::empty({static_cast(num_groups), rows, cols}, - at::device(device).dtype(dtype)); + out.packed = + at::empty({static_cast(num_groups), rows, cols}, at::device(device).dtype(dtype)); out.owns_storage = true; out.is_grouped = prefer_grouped_output; if (out.is_grouped) { @@ -345,9 +338,8 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight, if (weights->is_grouped) { // Single grouped weight case: weights are packed as [G, N, K]. Wrap the // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM. - auto grouped_weight = - make_uniform_grouped_tensor(weights->packed, input->num_tensors(), weights->rows, - weights->cols); + auto grouped_weight = make_uniform_grouped_tensor(weights->packed, input->num_tensors(), + weights->rows, weights->cols); grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false); } else { // Discrete weight case: weights are a list of per-expert tensors. Use the @@ -413,7 +405,8 @@ GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups NVTE_CHECK(bias.defined(), "Bias tensor must be defined."); auto grouped = GroupedTensorWrapper( num_groups, std::vector{num_groups, static_cast(out_features)}); - grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), tensor_shape_1d(bias)); + grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), + tensor_shape_1d(bias)); return grouped; } @@ -498,7 +491,8 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a } else if (activation == "sreglu") { nvte_sreglu(te_input.data(), te_output.data(), stream); } else if (activation == "clamped_swiglu") { - nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), static_cast(activation_limit), + nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), + static_cast(activation_limit), static_cast(activation_alpha), static_cast(activation_glu_linear_offset), stream); } else if (activation == "srelu") { @@ -520,8 +514,7 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input, const std::string &activation, double activation_limit, - double activation_alpha, - double activation_glu_linear_offset) { + double activation_alpha, double activation_glu_linear_offset) { auto output = at::empty_like(input); auto te_grad = makeTransformerEngineTensor(grad); auto te_input = makeTransformerEngineTensor(input); @@ -568,7 +561,7 @@ at::Tensor grouped_mlp_activation_forward( double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) { auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size); auto activation_output = activation_forward_impl(activation_input, activation, activation_limit, - activation_alpha, activation_glu_linear_offset); + activation_alpha, activation_glu_linear_offset); if (!act_scales.has_value()) { return activation_output; } @@ -607,10 +600,9 @@ ActivationBackwardResult grouped_mlp_activation_backward( } auto grad_activation_input = - activation_backward_impl(grad_activation_output, activation_input, activation, activation_limit, - activation_alpha, activation_glu_linear_offset); - return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), - grad_act_scales}; + activation_backward_impl(grad_activation_output, activation_input, activation, + activation_limit, activation_alpha, activation_glu_linear_offset); + return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), grad_act_scales}; } } // namespace @@ -653,8 +645,7 @@ std::vector megacpp_grouped_mlp_forward( split_sizes, x.device(), std::vector{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features}, std::vector{true, true, true, true, true}, - std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, - true); + std::vector{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, true); // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and // offsets in the same order as the stride list above. The CuTe path also asks // for int32 split_points, but cuBLAS grouped GEMM does not consume them. @@ -675,10 +666,9 @@ std::vector megacpp_grouped_mlp_forward( &gemm_resources); add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features); - auto fc2_x = - grouped_mlp_activation_forward(fc1_preact, act_scales, activation, glu_interleave_size, - activation_limit, activation_alpha, - activation_glu_linear_offset, dtype); + auto fc2_x = grouped_mlp_activation_forward( + fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha, + activation_glu_linear_offset, dtype); std::vector out_shape = input.sizes().vec(); out_shape.back() = fc2_out_features; @@ -692,22 +682,20 @@ std::vector megacpp_grouped_mlp_forward( &gemm_resources); add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features); - return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets, - output_offsets, fc1_preact, fc2_x}; + return {output, x, split_sizes_i64, base_offsets, x_offsets, + fc1_offsets, fc2_offsets, output_offsets, fc1_preact, fc2_x}; } py::tuple megacpp_grouped_mlp_backward( const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets, - const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, - const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, - const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x, - const std::optional &act_scales, py::handle fc1_weight, - py::handle fc2_weight, - py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad, - py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad, - const std::string &activation, int64_t glu_interleave_size, double activation_limit, - double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad, - bool input_requires_grad) { + const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets, + const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input, + const at::Tensor &fc2_x, const std::optional &act_scales, py::handle fc1_weight, + py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad, + bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad, + bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size, + double activation_limit, double activation_alpha, double activation_glu_linear_offset, + bool act_scales_requires_grad, bool input_requires_grad) { (void)base_offsets; NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output."); at::cuda::CUDAGuard device_guard(grad_output.device()); @@ -737,23 +725,20 @@ py::tuple megacpp_grouped_mlp_backward( fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features}); auto grouped_fc2_x_for_wgrad = make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features); - fc2_wgrads = - grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, - fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, - fc2_out_features, fc2_in_features, "fc2_wgrad_output", - fc2_weights.is_grouped); + fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output, + fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype, + fc2_out_features, fc2_in_features, "fc2_wgrad_output", + fc2_weights.is_grouped); } auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options()); auto grouped_fc2_dx = make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features); - grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, - &gemm_resources); + grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources); auto activation_grads = grouped_mlp_activation_backward( - fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, - activation_limit, activation_alpha, activation_glu_linear_offset, dtype, - act_scales_requires_grad); + fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit, + activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad); auto fc1_dy = activation_grads.grad_input; auto grad_act_scales = activation_grads.grad_act_scales; auto grouped_fc1_dy = @@ -766,11 +751,10 @@ py::tuple megacpp_grouped_mlp_backward( x_for_wgrad = x_for_wgrad.view({-1, in_features}); auto grouped_x_for_wgrad = make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features); - fc1_wgrads = - grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, - fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, - fc1_out_features, in_features, "fc1_wgrad_output", - fc1_weights.is_grouped); + fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output, + fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype, + fc1_out_features, in_features, "fc1_wgrad_output", + fc1_weights.is_grouped); } at::Tensor grad_input; @@ -779,8 +763,8 @@ py::tuple megacpp_grouped_mlp_backward( grad_input_shape.back() = in_features; grad_input = at::empty(grad_input_shape, dy.options()); auto grad_input_2d = grad_input.view({-1, in_features}); - auto grouped_grad_input = make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, - x_offsets, in_features); + auto grouped_grad_input = + make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, x_offsets, in_features); grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input, &gemm_resources); } else { diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py index ebaf30d075..a0a69c5804 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp_megacpp.py @@ -324,11 +324,15 @@ def fuser_backward( else (None,) ) - return grad_input, [fc1_grad_params, (), fc2_grad_params], [ - (None,), - activation_grad_extra, - (None,), - ] + return ( + grad_input, + [fc1_grad_params, (), fc2_grad_params], + [ + (None,), + activation_grad_extra, + (None,), + ], + ) def fuse_backward_megacpp_ops( @@ -349,15 +353,9 @@ def fuse_backward_megacpp_ops( window, ops = ops[:3], ops[3:] while len(window) == 3: matches_pattern = True - if not ( - isinstance(window[0], GroupedLinear) - and isinstance(window[2], GroupedLinear) - ): + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): matches_pattern = False - elif ( - window[0]._scale_bias - or window[2]._scale_bias - ): + elif window[0]._scale_bias or window[2]._scale_bias: matches_pattern = False else: try: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py index bd05b6218f..61906e3714 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp_megacpp.py @@ -280,8 +280,12 @@ def fuser_forward( x._do_not_clear = True if requires_grad: - fc1_saved_weights = [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights - fc2_saved_weights = [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + fc1_saved_weights = ( + [fc1_weights] if isinstance(fc1_weights, torch.Tensor) else fc1_weights + ) + fc2_saved_weights = ( + [fc2_weights] if isinstance(fc2_weights, torch.Tensor) else fc2_weights + ) fc1_ctx.save_for_backward( split_sizes_i64, @@ -338,15 +342,9 @@ def fuse_forward_megacpp_ops( window, ops = ops[:3], ops[3:] while len(window) == 3: matches_pattern = True - if not ( - isinstance(window[0], GroupedLinear) - and isinstance(window[2], GroupedLinear) - ): + if not (isinstance(window[0], GroupedLinear) and isinstance(window[2], GroupedLinear)): matches_pattern = False - elif ( - window[0]._scale_bias - or window[2]._scale_bias - ): + elif window[0]._scale_bias or window[2]._scale_bias: matches_pattern = False else: try: