Add Cutlass MxFp8 Block Scale Matrix Multiplication #5736
Conversation
Review updated until commit 959e14b

Relevant files: Enhancement, Documentation, Tests, Configuration changes

PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Missing Performance Data
Greptile Summary

This PR adds support for MxFp8 (microscaling FP8) block-scaled matrix multiplication to `nvfuser_direct` by implementing CUTLASS kernels optimized for SM100+ (compute capability 10.x) architectures.

Confidence Score: 5/5

Sequence Diagram

```mermaid
sequenceDiagram
participant User as Python User
participant Binding as cutlass.cpp (Python Binding)
participant API as mxfp8_scaled_mm (nvf_cutlass.cpp)
participant Validator as validateInputsMxFp8ScaledMm
participant Kernel as runGemm<T>
participant CUTLASS as CUTLASS GEMM Adapter
User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks)
Validator->>Validator: Check dimensions (a.dim==2, b.dim==2)
Validator->>Validator: Check CUDA device & contiguity
Validator->>Validator: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
Validator->>Validator: Check alignment (K%16==0, N%16==0)
Validator->>Validator: Validate scale matrix shapes
Validator-->>API: Return (m, n, k)
API->>API: Create output tensor
API->>Kernel: runGemm<cutlass::half_t or bfloat16_t>(...)
Kernel->>Kernel: args_from_options (setup CUTLASS arguments)
Kernel->>Kernel: Allocate workspace
Kernel->>CUTLASS: gemm.can_implement(arguments)
CUTLASS-->>Kernel: Status
Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
CUTLASS-->>Kernel: Status
Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
CUTLASS-->>Kernel: Status (compute C = alpha * A @ B)
Kernel-->>API: Return
API-->>Binding: Return output tensor
Binding-->>User: Return torch.Tensor
```
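For reference, here is a minimal Python sketch of the checks the validator performs, assuming a PyTorch build that exposes `torch.float8_e8m0fnu`; the real `validateInputsMxFp8ScaledMm` is C++, and the function below is only illustrative:

```python
import torch

def validate_mxfp8_inputs(a, b, scales_a, scales_b):
    # Both operands must be 2D, CUDA-resident, contiguous FP8 tensors.
    assert a.dim() == 2 and b.dim() == 2
    assert a.is_cuda and b.is_cuda
    assert a.is_contiguous() and b.is_contiguous()
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    # Scale factors use the e8m0 (power-of-two exponent) format.
    assert scales_a.dtype == torch.float8_e8m0fnu
    assert scales_b.dtype == torch.float8_e8m0fnu
    m, k = a.shape
    n = b.shape[0]
    assert b.shape[1] == k, "K dimensions of a and b must match"
    # Alignment requirements from the kernel: K and N divisible by 16.
    assert k % 16 == 0 and n % 16 == 0
    return m, n, k
```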
6 files reviewed, 5 comments
Force-pushed from 67ef893 to 372473e.
Force-pushed from 372473e to 71390b5.
6 files reviewed, 1 comment
Force-pushed from 71390b5 to ca3d9f6.
!test
jacobhinkle left a comment:
LGTM other than some minor comments. This appears to match the nvfp4 versions pretty closely as expected.
Force-pushed from ca3d9f6 to 959e14b.
!test
Greptile Overview
Greptile Summary
This PR adds MxFp8 (Microscaling FP8) block-scaled matrix multiplication support to nvfuser_direct, following the established pattern from the existing NVFP4 implementations. The implementation includes:
- Core kernel implementation (`mxfp8_scaled_mm.cu`): SM100+ CUTLASS kernel with `float_e4m3fn` input format and `float_e8m0fnu` scale factors, supporting FP16/BF16 outputs
- Kernel configuration: uses `Shape<_256, _256, _256>` MMA tiles, a `Shape<_2, _4, _1>` cluster shape (different from NVFP4's `Shape<_4, _4, _1>`), and `Shape<_128, _256, _256>` per-SM tiles, as specified
- Validation and API: comprehensive input validation with alignment checks (K and N must be divisible by 16; scales must be padded/swizzled to 128x4 blocks, as illustrated in the sketch after this list)
- Python bindings: clean integration following existing patterns
- Test coverage: parametrized tests across multiple dtypes (FP16, BF16) and shapes, validated against a reference implementation

The code follows existing conventions, includes proper error handling, and is well documented.
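As a rough illustration of the padding requirement above, here is a sketch assuming the standard MxFp8 block size of 32 elements along K (the helper name is hypothetical, not part of the PR):

```python
import math

def padded_scale_shape(rows: int, k: int, block_size: int = 32) -> tuple[int, int]:
    # One e8m0 scale covers a contiguous block of 32 elements along K,
    # so the unpadded scale matrix is rows x (k // block_size).
    scale_cols = k // block_size
    # Pad rows up to a multiple of 128 and scale columns up to a multiple
    # of 4 to satisfy the 128x4 swizzled block layout the kernel expects.
    return (
        math.ceil(rows / 128) * 128,
        math.ceil(scale_cols / 4) * 4,
    )

# e.g. a 100 x 1024 operand needs its scales padded to 128 x 32
assert padded_scale_shape(100, 1024) == (128, 32)
```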
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
- The implementation closely follows established patterns from NVFP4 kernels, includes comprehensive validation and error handling, has proper test coverage with reference implementations, and all changes are additive without modifying existing functionality
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| cutlass/mxfp8_scaled_mm.cu | 5/5 | New MxFp8 scaled matrix multiplication kernel with proper SM100+ support, input validation, and error handling |
| cutlass/nvf_cutlass.cpp | 5/5 | Implemented validation logic for MxFp8 inputs with proper dimension checks and alignment requirements |
| tests/python/direct/test_cutlass_mxfp8_gemm.py | 5/5 | Comprehensive test suite with multiple dtypes and shapes, proper quantization/dequantization for validation |
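To make the test strategy concrete, here is a hedged sketch of an MxFp8 reference matmul in the spirit of the test file, assuming unswizzled, unpadded scales and 32-element blocks; the actual test helpers may differ:

```python
import torch

def reference_mxfp8_mm(a, b, scales_a, scales_b, alpha, out_dtype, block_size=32):
    # Dequantize: broadcast each e8m0 scale across its 32-wide block along K,
    # then do a plain fp32 matmul. b is stored as (n, k), hence the .t().
    a_f = a.float() * scales_a.float().repeat_interleave(block_size, dim=1)
    b_f = b.float() * scales_b.float().repeat_interleave(block_size, dim=1)
    return (alpha * (a_f @ b_f.t())).to(out_dtype)
```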
Sequence Diagram

```mermaid
sequenceDiagram
participant User as Python User
participant Binding as cutlass.cpp (Python Binding)
participant API as mxfp8_scaled_mm
participant Validate as validateInputsMxFp8ScaledMm
participant Kernel as runGemm<T>
participant CUTLASS as MxFp8GemmSm100::Gemm
User->>Binding: mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
Binding->>API: cutlass_kernels::mxfp8_scaled_mm(...)
API->>Validate: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha)
Validate->>Validate: Check tensor dimensions (2D matrices)
Validate->>Validate: Check K dimensions match (a.size[1] == b.size[1])
Validate->>Validate: Check CUDA device & contiguity
Validate->>Validate: Validate dtypes (Float8_e4m3fn, Float8_e8m0fnu)
Validate->>Validate: Check alignment (K % 16 == 0, N % 16 == 0)
Validate->>Validate: Validate scale matrix shapes (padded to 128x4 blocks)
Validate-->>API: Return (m, n, k)
API->>API: Set CUDA device guard
API->>API: Create output tensor (m x n, dtype)
alt dtype == Half
API->>Kernel: runGemm<cutlass::half_t>(...)
else dtype == BFloat16
API->>Kernel: runGemm<cutlass::bfloat16_t>(...)
end
Kernel->>Kernel: args_from_options (setup strides, layouts)
Kernel->>Kernel: Allocate workspace memory
Kernel->>CUTLASS: gemm.can_implement(arguments)
CUTLASS-->>Kernel: Status::kSuccess
Kernel->>CUTLASS: gemm.initialize(arguments, workspace, stream)
CUTLASS-->>Kernel: Status::kSuccess
Kernel->>CUTLASS: gemm.run(arguments, workspace, stream)
Note over CUTLASS: Execute SM100 Block-Scaled TensorOp<br/>MMA: 256x256x256<br/>Cluster: 2x4x1
CUTLASS-->>Kernel: Status::kSuccess
Kernel-->>API: void (output filled)
API-->>Binding: output tensor
Binding-->>User: output tensor
```

This PR adds MxFp8 cutlass kernels to `nvfuser_direct`. Kernel configuration: `Shape<_256, _256, _256>` MMA tiles, `Shape<_2, _4, _1>` cluster shape, and `Shape<_128, _256, _256>` per-SM tiles.
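For orientation, a hedged end-to-end usage sketch matching the call flow in the diagrams above; the import path and the quantization helper are assumptions for illustration, not the PR's exact API:

```python
import torch
from nvfuser_direct import nvf_cutlass  # import path is an assumption

def quantize_mxfp8(x: torch.Tensor, block_size: int = 32):
    # Hypothetical helper: one power-of-two (e8m0) scale per 32-element
    # block along K, chosen so the block's max magnitude fits within
    # e4m3's largest representable value (448).
    m, k = x.shape
    blocks = x.view(m, k // block_size, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scales = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    q = (blocks / scales).reshape(m, k).to(torch.float8_e4m3fn)
    return q, scales.view(m, k // block_size).to(torch.float8_e8m0fnu)

m, n, k = 256, 512, 1024  # K and N must be divisible by 16
a, scales_a = quantize_mxfp8(torch.randn(m, k, device="cuda"))
b, scales_b = quantize_mxfp8(torch.randn(n, k, device="cuda"))  # b is (n, k)

# Note: the kernel expects scales padded/swizzled into 128x4 blocks; that
# layout step is omitted here for brevity.
out = nvf_cutlass.mxfp8_scaled_mm(a, b, scales_a, scales_b, 1.0, torch.bfloat16)
print(out.shape)  # torch.Size([256, 512])
```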