
Conversation

@rdspring1 (Collaborator) commented Dec 23, 2025

This PR adds MxFp8 CUTLASS kernels to nvfuser_direct.

  • FP16 and BF16 output dtypes.
  • MmaTileShape is Shape<_256, _256, _256>.
  • ClusterShape is Shape<_2, _4, _1>.
  • PerSmTileShape_MNK is Shape<_128, _256, _256>.
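
A minimal usage sketch of the new Python entry point, based on the test added in this PR. The import path is assumed, and the quantization/swizzle helpers are test utilities referenced in the test excerpt later in this thread, not part of the public API:

    import torch
    from nvfuser_direct import nvf_cutlass  # import path assumed

    m, n, k = 256, 256, 256
    a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
    b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")

    # Quantize to MXFP8 (float_e4m3fn data + e8m0 block scales) and swizzle the
    # scale factors into the interleaved 128x4 layout the kernel expects.
    a_fp8, a_scale = pytorch_mxfp8_quantize(a)   # test helper, not public API
    b_fp8, b_scale = pytorch_mxfp8_quantize(b)
    a_scale = linear_to_swizzled_128_4(a_scale)  # test helper, not public API
    b_scale = linear_to_swizzled_128_4(b_scale)

    alpha = torch.tensor(1.0, device="cuda")
    out = nvf_cutlass.mxfp8_scaled_mm(
        a_fp8, b_fp8, a_scale, b_scale, alpha, torch.bfloat16
    )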

@github-actions github-actions bot commented Dec 23, 2025

Review updated until commit 959e14b

Description

  • Adds MXFP8 scaled matrix multiplication support using CUTLASS kernels for SM100+ GPUs

  • Implements FP16 and BF16 output support with optimized tile shapes and cluster configurations

  • Includes comprehensive input validation for MXFP8 format tensors and scale matrices

  • Provides Python bindings and extensive test coverage for the new functionality
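
For orientation, MXFP8 pairs each run of 32 consecutive FP8 (e4m3) values along K with one shared e8m0 scale. Below is a minimal dequantization sketch illustrating that convention; the function name and the linear, pre-swizzle scale layout are illustrative only, and the PR's tests use their own pytorch_mxfp8_quantize helper instead:

    import torch

    def mxfp8_dequantize(x_fp8: torch.Tensor, scales: torch.Tensor, block_size: int = 32) -> torch.Tensor:
        # x_fp8: (rows, k) float8_e4m3fn data; scales: (rows, k // block_size) e8m0 scales,
        # one per 32-element block along K (linear layout, before swizzling).
        x = x_fp8.to(torch.float32)
        s = scales.to(torch.float32).repeat_interleave(block_size, dim=1)
        return x * s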

Changes walkthrough

Relevant files

Enhancement

  cutlass/mxfp8_scaled_mm.cu: Core MXFP8 GEMM implementation with CUTLASS kernels (+316/-0)
    • Implements MXFP8 scaled matrix multiplication using CUTLASS kernels for SM100+ architecture
    • Defines kernel traits and configurations for FP16/BF16 outputs with optimized tile shapes
    • Provides main mxfp8_scaled_mm function with proper error handling and workspace management
    • Includes comprehensive argument construction and GEMM execution logic

  cutlass/nvf_cutlass.cpp: Input validation for MXFP8 scaled matrix multiplication (+112/-0)
    • Adds validateInputsMxFp8ScaledMm function for comprehensive input validation
    • Validates CUDA device, contiguity, data types, and alignment requirements
    • Checks scale matrix properties and padding requirements for optimal performance

  python/python_direct/cutlass.cpp: Python bindings for MXFP8 scaled matrix multiplication (+21/-1)
    • Adds Python binding for mxfp8_scaled_mm function
    • Updates nvfp4_scaled_mm docstring to clarify supported output types
    • Provides proper Python interface for the new MXFP8 functionality

Documentation

  cutlass/nvf_cutlass.h: Header declarations and documentation for MXFP8 functions (+52/-1)
    • Declares validateInputsMxFp8ScaledMm and mxfp8_scaled_mm functions
    • Updates documentation for existing functions to clarify output types
    • Provides comprehensive parameter documentation for the new MXFP8 functionality

Tests

  tests/python/direct/test_cutlass_mxfp8_gemm.py: Test suite for MXFP8 GEMM functionality (+124/-0)
    • Comprehensive test suite for MXFP8 GEMM functionality
    • Helper functions for quantization, dequantization, and reference computation
    • Parameterized tests for different shapes and data types (FP16/BF16)
    • Validates compute capability requirements and tensor properties

Configuration changes

  CMakeLists.txt: Build configuration update for MXFP8 source file (+1/-0)
    • Adds mxfp8_scaled_mm.cu to NVFUSER_CUTLASS_SRCS list
    • Includes new source file in the build configuration

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Missing Performance Data

    The PR lacks quantitative performance benchmarks comparing the MXFP8 implementation against existing baselines (e.g., FP16/BF16 GEMM, other quantization methods). No roofline analysis or performance goals are provided to validate the effectiveness of this implementation.

    // clang-format off
    /*
     * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
     * All rights reserved.
     * SPDX-License-Identifier: BSD-3-Clause
     */
    // clang-format on
    #include <cutlass_utils.h>
    #include <exceptions.h>
    #include <nvf_cutlass.h>
    
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <torch/torch.h>
    
    #include "cutlass/cutlass.h"
    #include "cutlass/epilogue/collective/collective_builder.hpp"
    #include "cutlass/gemm/collective/collective_builder.hpp"
    #include "cutlass/gemm/device/gemm_universal_adapter.h"
    #include "cutlass/gemm/kernel/gemm_universal.hpp"
    #include "cutlass/util/packed_stride.hpp"
    
    namespace nvfuser::cutlass_kernels {
    
    namespace {
    
    using namespace cute;
    
    #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
    // Kernel configuration traits for different output data types
    // Defines tile shapes and cluster configurations.
    template <typename T>
    struct KernelTraits;
    
    // Kernel traits for FP16 output
    template <>
    struct KernelTraits<cutlass::half_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
    
    // Kernel traits for BF16 output
    template <>
    struct KernelTraits<cutlass::bfloat16_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
    
    // Main GEMM configuration for MXFP8 scaled matrix multiplication on SM100+
    // Defines all the types, layouts, and configurations needed for the CUTLASS
    // kernel
    template <typename T>
    struct MxFp8GemmSm100 {
      // A matrix configuration
      using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
      using LayoutATag = cutlass::layout::RowMajor;
      static constexpr int kAlignmentA = 16;
    
      // B matrix configuration
      using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
      using LayoutBTag = cutlass::layout::ColumnMajor;
      static constexpr int kAlignmentB = 16;
    
      // C/D matrix configuration
      using ElementD = T;
      using ElementC = T;
      using LayoutCTag = cutlass::layout::RowMajor;
      using LayoutDTag = cutlass::layout::RowMajor;
      static constexpr int kAlignmentD =
          128 / cutlass::sizeof_bits<ElementD>::value;
      static constexpr int kAlignmentC =
          128 / cutlass::sizeof_bits<ElementC>::value;
      // Kernel functional config
      using ElementAccumulator = float;
      using ArchTag = cutlass::arch::Sm100;
      using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
    
      // Kernel Perf config
      using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
      using ClusterShape = typename KernelTraits<T>::ClusterShape;
      using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
    
      using CollectiveEpilogue =
          typename cutlass::epilogue::collective::CollectiveBuilder<
              ArchTag,
              OperatorClass,
              PerSmTileShape_MNK,
              ClusterShape,
              cutlass::epilogue::collective::EpilogueTileAuto,
              ElementAccumulator,
              ElementAccumulator,
              ElementC,
              LayoutCTag,
              kAlignmentC,
              ElementD,
              LayoutDTag,
              kAlignmentD,
              cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
    
      using CollectiveMainloop =
          typename cutlass::gemm::collective::CollectiveBuilder<
              ArchTag,
              OperatorClass,
              ElementA,
              LayoutATag,
              kAlignmentA,
              ElementB,
              LayoutBTag,
              kAlignmentB,
              ElementAccumulator,
              MmaTileShape,
              ClusterShape,
              cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
                  sizeof(typename CollectiveEpilogue::SharedStorage))>,
              cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
    
      using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
          Shape<int, int, int, int>,
          CollectiveMainloop,
          CollectiveEpilogue,
          void>;
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
    
      // Reference device GEMM implementation type
      using StrideA = typename Gemm::GemmKernel::StrideA;
      using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
      // Scale Factor tensors have an interleaved layout. Bring Layout instead of
      // stride.
      using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
      using StrideB = typename Gemm::GemmKernel::StrideB;
      using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
      // Scale Factor tensors have an interleaved layout. Bring Layout instead of
      // stride.
      using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
      using StrideC = typename Gemm::GemmKernel::StrideC;
      using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
      using StrideD = typename Gemm::GemmKernel::StrideD;
      using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
    };
    
    // Constructs CUTLASS GEMM arguments from PyTorch tensors and dimensions
    //
    // This function converts PyTorch tensor data and metadata into the format
    // expected by CUTLASS GEMM kernels, including proper stride calculations
    // and layout configurations for the scaled matrix multiplication.
    //
    // Parameters:
    //   output: Output tensor for storing results
    //   a: Input matrix A in MXFP8 format
    //   b: Input matrix B in MXFP8 format
    //   scales_a: Per-block scaling factors for matrix A
    //   scales_b: Per-block scaling factors for matrix B
    //   alpha: Global scaling factor
    //   M, N, K: Matrix dimensions
    //
    // Returns: CUTLASS GEMM arguments structure ready for kernel execution
    template <typename T>
    typename T::Gemm::Arguments args_from_options(
        at::Tensor& output,
        const at::Tensor& a,
        const at::Tensor& b,
        const at::Tensor& scales_a,
        const at::Tensor& scales_b,
        const at::Tensor& alpha,
        int64_t M,
        int64_t N,
        int64_t K) {
      using ElementA = typename T::Gemm::ElementA;
      using ElementB = typename T::Gemm::ElementB;
      using ElementSFA = cutlass::float_ue8m0_t;
      using ElementSFB = cutlass::float_ue8m0_t;
      using ElementD = typename T::Gemm::ElementD;
      using ElementCompute = float;
      using StrideA = typename T::StrideA;
      using StrideB = typename T::StrideB;
      using StrideD = typename T::StrideD;
      using Sm1xxBlkScaledConfig =
          typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
    
      int m = static_cast<int>(M);
      int n = static_cast<int>(N);
      int k = static_cast<int>(K);
      auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
      auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
      auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
    
      auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
          cute::make_shape(m, n, k, 1));
      auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
          cute::make_shape(m, n, k, 1));
    
      typename T::Gemm::Arguments arguments{
          cutlass::gemm::GemmUniversalMode::kGemm,
          {m, n, k, 1},
          {// Mainloop arguments
           static_cast<ElementA const*>(a.data_ptr()),
           stride_A,
           static_cast<ElementB const*>(b.data_ptr()),
           stride_B,
           static_cast<ElementSFA const*>(scales_a.data_ptr()),
           layout_SFA,
           static_cast<ElementSFB const*>(scales_b.data_ptr()),
           layout_SFB},
          {// Epilogue arguments
           {}, // epilogue.thread
           static_cast<ElementD const*>(output.data_ptr()),
           stride_D,
           static_cast<ElementD*>(output.data_ptr()),
           stride_D}};
      auto& fusion_args = arguments.epilogue.thread;
      fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
      return arguments;
    }
    
    // Executes the MXFP8 scaled matrix multiplication using CUTLASS kernels
    //
    // This function orchestrates the GEMM operation by setting up the kernel,
    // allocating workspace memory, and running the computation on the GPU.
    // It handles the complete lifecycle from kernel initialization to execution.
    //
    // Parameters:
    //   output: Output tensor to store the result
    //   a, b: Input matrices in MXFP8 format
    //   scales_a, scales_b: Per-block scaling factors
    //   alpha: Global scaling factor
    //   m, n, k: Matrix dimensions
    //   stream: CUDA stream for asynchronous execution
    template <typename T>
    void runGemm(
        at::Tensor& output,
        const at::Tensor& a,
        const at::Tensor& b,
        const at::Tensor& scales_a,
        const at::Tensor& scales_b,
        const at::Tensor& alpha,
        int64_t m,
        int64_t n,
        int64_t k,
        cudaStream_t stream) {
      typename MxFp8GemmSm100<T>::Gemm gemm;
    
      auto arguments = args_from_options<MxFp8GemmSm100<T>>(
          output, a, b, scales_a, scales_b, alpha, m, n, k);
    
      size_t workspace_size =
          MxFp8GemmSm100<T>::Gemm::get_workspace_size(arguments);
      auto const workspace_options =
          torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
      auto workspace = torch::empty(workspace_size, workspace_options);
    
      auto can_implement_status = gemm.can_implement(arguments);
      NVF_CHECK(
          can_implement_status == cutlass::Status::kSuccess,
          "Failed to implement GEMM");
    
      auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
      NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
    
      status = gemm.run(arguments, workspace.data_ptr(), stream);
      NVF_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
    }
    #else
    // Fallback implementation for unsupported CUTLASS versions
    // Throws an error when SM100+ CUTLASS support is not available
    template <typename T>
    void runGemm(
        at::Tensor& output,
        at::Tensor const& a,
        at::Tensor const& b,
        at::Tensor const& scales_a,
        at::Tensor const& scales_b,
        at::Tensor const& alpha,
        int64_t m,
        int64_t n,
        int64_t k,
        cudaStream_t stream) {
      NVF_THROW("Unsupported CUTLASS version.");
    }
    #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
    
    } // namespace
    
    torch::Tensor mxfp8_scaled_mm(
        const torch::Tensor& a,
        const torch::Tensor& b,
        const torch::Tensor& scales_a,
        const torch::Tensor& scales_b,
        const torch::Tensor& alpha,
        const at::ScalarType out_dtype,
        bool skip_checks) {
      // Validate all inputs and get matrix dimensions
      auto [m, n, k] =
          validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks);
    
      at::cuda::CUDAGuard device_guard{(int8_t)a.get_device()};
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
    
      auto options =
          at::TensorOptions().dtype(out_dtype).device(at::kCUDA, a.get_device());
      torch::Tensor output = at::empty({a.sizes()[0], b.sizes()[0]}, options);
    
      if (out_dtype == at::ScalarType::Half) {
        runGemm<cutlass::half_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else if (out_dtype == at::ScalarType::BFloat16) {
        runGemm<cutlass::bfloat16_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else {
        NVF_THROW("Unsupported output data type of mxfp8 scaled_mm.");
      }
      return output;
    }
    
    } // namespace nvfuser::cutlass_kernels
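
    On the missing-performance-data observation above, here is a minimal timing sketch that could be used to collect baseline numbers against a plain BF16 matmul. Nothing below is part of the PR; the import path is assumed, and the MXFP8 call is commented out because it needs the quantization helpers from the PR's tests:

        import torch
        from nvfuser_direct import nvf_cutlass  # import path assumed

        def time_gpu_ms(fn, iters=100, warmup=10):
            # CUDA-event timing; warm up first so one-time setup costs are excluded.
            for _ in range(warmup):
                fn()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(iters):
                fn()
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end) / iters  # milliseconds per call

        m, n, k = 4096, 4096, 4096
        a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
        b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")
        baseline_ms = time_gpu_ms(lambda: a @ b.t())
        print(f"BF16 matmul baseline: {baseline_ms:.3f} ms")

        # mxfp8_ms = time_gpu_ms(lambda: nvf_cutlass.mxfp8_scaled_mm(
        #     a_fp8, b_fp8, a_scale, b_scale, alpha, torch.bfloat16))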
    Limited Architecture Support

    The implementation is restricted to SM100+ architecture with no fallback or alternative paths for older GPU architectures. This significantly limits the practical deployment scope of this feature.

    #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
    // Kernel configuration traits for different output data types
    // Defines tile shapes and cluster configurations.
    template <typename T>
    struct KernelTraits;
    
    // Kernel traits for FP16 output
    template <>
    struct KernelTraits<cutlass::half_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
    
    // Kernel traits for BF16 output
    template <>
    struct KernelTraits<cutlass::bfloat16_t> {
      using MmaTileShape = Shape<_256, _256, _256>;
      using ClusterShape = Shape<_2, _4, _1>;
      using PerSmTileShape_MNK = Shape<_128, _256, _256>;
    };
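
    On the architecture-support observation, a runtime guard keeps callers off unsupported GPUs. A minimal sketch mirroring the compute-capability check the PR's tests perform (the helper name and skip message are illustrative, not taken from the PR):

        import pytest
        import torch

        def is_sm100_or_newer() -> bool:
            # MXFP8 block-scaled tensor cores require compute capability 10.0 or newer.
            major, _minor = torch.cuda.get_device_capability()
            return major >= 10

        requires_sm100 = pytest.mark.skipif(
            not torch.cuda.is_available() or not is_sm100_or_newer(),
            reason="mxfp8_scaled_mm requires an SM100+ GPU",
        )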
    
    Incomplete Test Coverage

    Tests only cover basic correctness but lack edge cases like non-divisible dimensions, very large matrices, numerical stability under extreme values, and performance regression testing.

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    @pytest.mark.parametrize(
        "shape", [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]
    )
    @torch.inference_mode()
    def test_mxfp8_gemm(
        dtype: torch.dtype,
        shape: tuple[int, int, int],
    ) -> None:
        m, n, k = shape
        block_size = 32
        a_dtype = torch.randn((m, k), dtype=dtype, device="cuda")
        b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
    
        alpha = torch.tensor(1.0, device="cuda")
        a_fp8, a_scale_linear = pytorch_mxfp8_quantize(a_dtype)
        b_fp8, b_scale_linear = pytorch_mxfp8_quantize(b_dtype)
        a_scale_interleaved = linear_to_swizzled_128_4(a_scale_linear)
        b_scale_interleaved = linear_to_swizzled_128_4(b_scale_linear)
    
        expected_out = get_ref_results(
            a_fp8,
            b_fp8,
            a_scale_interleaved,
            b_scale_interleaved,
            m,
            n,
        )
        out = nvf_cutlass.mxfp8_scaled_mm(
            a_fp8, b_fp8, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
    
        torch.testing.assert_close(out, expected_out.to(dtype=dtype))
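
    One of the edge cases called out above, dimensions violating the K % 16 == 0 / N % 16 == 0 alignment requirement, could be covered with a negative test along these lines. This is a sketch only: the exception type and the direct construction of e8m0 scale tensors are assumptions, and the latter needs a PyTorch build with float8_e8m0fnu support:

        @torch.inference_mode()
        def test_mxfp8_gemm_rejects_misaligned_n() -> None:
            m, n, k = 128, 120, 128  # n % 16 != 0, which the validation layer rejects
            a_fp8 = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)
            b_fp8 = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fn)
            # One e8m0 scale per 32-element block along K, padded to the swizzled 128x4 layout.
            a_scale = torch.ones((128, 4), device="cuda").to(torch.float8_e8m0fnu)
            b_scale = torch.ones((128, 4), device="cuda").to(torch.float8_e8m0fnu)
            alpha = torch.tensor(1.0, device="cuda")
            with pytest.raises(RuntimeError):  # exact exception type is an assumption
                nvf_cutlass.mxfp8_scaled_mm(
                    a_fp8, b_fp8, a_scale, b_scale, alpha, torch.bfloat16
                )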

    @greptile-apps greptile-apps bot (Contributor) commented Dec 23, 2025

    Greptile Summary

    This PR adds support for MxFp8 (microscaling FP8) block-scaled matrix multiplication to nvfuser_direct by implementing CUTLASS kernels optimized for SM100+ (compute capability 10.x) architectures.

    Key Changes:

    • Implements mxfp8_scaled_mm kernel in cutlass/mxfp8_scaled_mm.cu using CUTLASS 3.x collective builders with cluster shape Shape<_2, _4, _1> and tile shape Shape<_256, _256, _256>
    • Supports FP16 and BF16 output dtypes with Float8_e4m3fn input matrices and Float8_e8m0fnu block scale factors
    • Follows established patterns from nvfp4_scaled_mm.cu for consistency across the codebase
    • Adds comprehensive input validation in validateInputsMxFp8ScaledMm with alignment checks (K and N must be divisible by 16) and scale matrix shape validation
    • Includes Python bindings and test suite with quantization/dequantization utilities for correctness verification

    Architecture:
    The implementation uses the same architecture as existing NVFP4 kernels: validation layer → kernel launcher → CUTLASS adapter. The MxFp8 format uses per-block (32 elements) scaling factors stored in an interleaved/swizzled layout for optimal memory access patterns.
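
    A small sketch of the scale-factor bookkeeping described above; the shapes follow from the 32-element block size and the 128x4 swizzle padding mentioned in this thread, though the exact padding rule is inferred rather than quoted from the PR:

        import math

        def mxfp8_scale_shape(rows: int, k: int, block_size: int = 32) -> tuple[int, int]:
            # One e8m0 scale per 32-element block along K, padded so the swizzled
            # layout covers whole 128x4 tiles.
            num_blocks = math.ceil(k / block_size)
            padded_rows = math.ceil(rows / 128) * 128
            padded_cols = math.ceil(num_blocks / 4) * 4
            return padded_rows, padded_cols

        # For the (m, n, k) = (256, 128, 256) test case:
        # scales_a is 256 x 8 and scales_b is 128 x 8 after padding/swizzling.
        print(mxfp8_scale_shape(256, 256), mxfp8_scale_shape(128, 256))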

    Confidence Score: 5/5

    • This PR is safe to merge with no critical issues found
    • The implementation follows established patterns from existing nvfp4 kernels, includes comprehensive validation logic, and provides thorough test coverage. The code is well-documented and properly integrates with the existing build system. Previous review comments addressed minor error-message corrections.
    • No files require special attention

    Important Files Changed

    Filename: Overview
    cutlass/mxfp8_scaled_mm.cu: New CUTLASS kernel implementation for MxFp8 matrix multiplication with proper SM100+ architecture support, follows existing patterns from nvfp4_scaled_mm.cu
    cutlass/nvf_cutlass.cpp: Adds comprehensive input validation for MxFp8 operations with proper dtype checks and alignment requirements
    cutlass/nvf_cutlass.h: Header declarations for MxFp8 API matching existing NVFP4 patterns with appropriate documentation
    python/python_direct/cutlass.cpp: Python bindings for mxfp8_scaled_mm added correctly with proper documentation
    tests/python/direct/test_cutlass_mxfp8_gemm.py: Comprehensive test suite with quantization, dequantization, and reference comparison for multiple dtypes and shapes
    CMakeLists.txt: Adds mxfp8_scaled_mm.cu to build system correctly

    Sequence Diagram

    sequenceDiagram
        participant User as Python User
        participant Binding as cutlass.cpp (Python Binding)
        participant API as mxfp8_scaled_mm (nvf_cutlass.cpp)
        participant Validator as validateInputsMxFp8ScaledMm
        participant Kernel as runGemm<T>
        participant CUTLASS as CUTLASS GEMM Adapter
    
        User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
        Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
        API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks)
        Validator->>Validator: Check dimensions (a.dim==2, b.dim==2)
        Validator->>Validator: Check CUDA device & contiguity
        Validator->>Validator: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
        Validator->>Validator: Check alignment (K%16==0, N%16==0)
        Validator->>Validator: Validate scale matrix shapes
        Validator-->>API: Return (m, n, k)
        API->>API: Create output tensor
        API->>Kernel: runGemm<cutlass::half_t or bfloat16_t>(...)
        Kernel->>Kernel: args_from_options (setup CUTLASS arguments)
        Kernel->>Kernel: Allocate workspace
        Kernel->>CUTLASS: gemm.can_implement(arguments)
        CUTLASS-->>Kernel: Status
        Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
        CUTLASS-->>Kernel: Status
        Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
        CUTLASS-->>Kernel: Status (compute C = alpha * A @ B)
        Kernel-->>API: Return
        API-->>Binding: Return output tensor
        Binding-->>User: Return torch.Tensor
    

    @greptile-apps greptile-apps bot (Contributor) left a comment


    6 files reviewed, 5 comments


    @greptile-apps greptile-apps bot (Contributor) commented Jan 5, 2026

    Greptile's behavior is changing!

    From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

    This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

    @greptile-apps greptile-apps bot (Contributor) left a comment


    6 files reviewed, 1 comment


    @rdspring1 (Collaborator, Author) commented:

    !test

    @rdspring1 rdspring1 added the Low precision FP8, FP4, MXFP8, nvFP4 label Jan 5, 2026
    @rdspring1 rdspring1 requested a review from jacobhinkle January 5, 2026 18:10
    @jacobhinkle jacobhinkle (Collaborator) left a comment


    LGTM other than some minor comments. This appears to match the nvfp4 versions pretty closely as expected.

    @rdspring1 (Collaborator, Author) commented:

    !test

    @greptile-apps greptile-apps bot (Contributor) left a comment


    Greptile Overview

    Greptile Summary

    This PR adds MxFp8 (Microscaling FP8) block-scaled matrix multiplication support to nvfuser_direct, following the established pattern from existing NVFP4 implementations. The implementation includes:

    • Core kernel implementation (mxfp8_scaled_mm.cu): SM100+ CUTLASS kernel with float_e4m3fn input format and float_e8m0fnu scale factors, supporting FP16/BF16 outputs
    • Kernel configuration: Uses Shape<_256, _256, _256> MMA tiles, Shape<_2, _4, _1> cluster shape (different from NVFP4's Shape<_4, _4, _1>), and Shape<_128, _256, _256> per-SM tiles as specified
    • Validation and API: Comprehensive input validation with alignment checks (K and N must be divisible by 16, scales must be padded/swizzled to 128x4 blocks)
    • Python bindings: Clean integration following existing patterns
    • Test coverage: Parametrized tests across multiple dtypes (FP16, BF16) and shapes with reference implementation validation

    The code follows existing conventions, includes proper error handling, and is well-documented.

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk
    • The implementation closely follows established patterns from NVFP4 kernels, includes comprehensive validation and error handling, has proper test coverage with reference implementations, and all changes are additive without modifying existing functionality
    • No files require special attention

    Important Files Changed

    File Analysis

    Filename (Score): Overview
    cutlass/mxfp8_scaled_mm.cu (5/5): New MxFp8 scaled matrix multiplication kernel with proper SM100+ support, input validation, and error handling
    cutlass/nvf_cutlass.cpp (5/5): Implemented validation logic for MxFp8 inputs with proper dimension checks and alignment requirements
    tests/python/direct/test_cutlass_mxfp8_gemm.py (5/5): Comprehensive test suite with multiple dtypes and shapes, proper quantization/dequantization for validation

    Sequence Diagram

    sequenceDiagram
        participant User as Python User
        participant Binding as cutlass.cpp (Python Binding)
        participant API as mxfp8_scaled_mm
        participant Validate as validateInputsMxFp8ScaledMm
        participant Kernel as runGemm<T>
        participant CUTLASS as MxFp8GemmSm100::Gemm
    
        User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
        Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
        
        API->>Validate: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha)
        Validate->>Validate: Check tensor dimensions (2D matrices)
        Validate->>Validate: Check K dimensions match (a.size[1] == b.size[1])
        Validate->>Validate: Check CUDA device & contiguity
        Validate->>Validate: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
        Validate->>Validate: Check alignment (K % 16 == 0, N % 16 == 0)
        Validate->>Validate: Validate scale matrix shapes (padded to 128x4 blocks)
        Validate-->>API: Return (m, n, k)
        
        API->>API: Set CUDA device guard
        API->>API: Create output tensor (m x n, dtype)
        
        alt dtype == Half
            API->>Kernel: runGemm<cutlass::half_t>(...)
        else dtype == BFloat16
            API->>Kernel: runGemm<cutlass::bfloat16_t>(...)
        end
        
        Kernel->>Kernel: args_from_options (setup strides, layouts)
        Kernel->>Kernel: Allocate workspace memory
        Kernel->>CUTLASS: gemm.can_implement(arguments)
        CUTLASS-->>Kernel: Status::kSuccess
        Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
        CUTLASS-->>Kernel: Status::kSuccess
        Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
        Note over CUTLASS: Execute SM100 Block-Scaled TensorOp<br/>MMA: 256x256x256<br/>Cluster: 2x4x1
        CUTLASS-->>Kernel: Status::kSuccess
        
        Kernel-->>API: void (output filled)
        API-->>Binding: output tensor
        Binding-->>User: output tensor
    

    @rdspring1 rdspring1 merged commit 463dfba into main Jan 8, 2026
    61 checks passed
    @rdspring1 rdspring1 deleted the mxfp8_cutlass branch January 8, 2026 16:33