Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,42 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
};

auto is_bf16_wgrad_dtype = [&]() -> bool {
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
auto A_type = get_cuda_dtype(inputA->data.dtype);
auto B_type = get_cuda_dtype(inputB->data.dtype);
auto D_type = get_cuda_dtype(OutputD->data.dtype);

return (A_type == CUDA_R_16BF) && (B_type == CUDA_R_16BF) &&
(D_type == CUDA_R_32F || D_type == CUDA_R_16BF);
};

// K-grouped BF16 wgrad shape eligibility: every group must be 2D NT with a matching
// (ragged) K and a uniform hidden/expert. Shapes outside this fall back to cuBLAS
// instead of hard-erroring inside the varlen-k kernel.
auto is_bf16_wgrad_shape = [&]() -> bool {
int64_t ref_hidden = -1, ref_expert = -1;
for (size_t i = 0; i < num_gemms; i++) {
const auto *inp = transformer_engine::convertNVTETensorCheck(A[i]);
const auto *grad = transformer_engine::convertNVTETensorCheck(B[i]);
if (inp->data.shape.size() != 2 || grad->data.shape.size() != 2) return false;
const int64_t k = inp->data.shape[0];
const int64_t hidden = inp->data.shape[1];
const int64_t expert = grad->data.shape[1];
if (static_cast<int64_t>(grad->data.shape[0]) != k || hidden <= 0 || expert <= 0)
return false;
if (ref_hidden < 0) {
ref_hidden = hidden;
ref_expert = expert;
} else if (hidden != ref_hidden || expert != ref_expert) {
return false;
}
}
return true;
};

// CUTLASS Grouped GEMM fast path (SM90/TMA)
// Conditions:
// - No fused epilogue: both bias and pre_gelu_out are empty.
Expand All @@ -1123,6 +1159,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
all_groups_uniform_k128(B, transb)) {
cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
current_device, math_sm_count, stream);
} else if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_bf16_wgrad_dtype() && !transa &&
transb && grad && is_bf16_wgrad_shape()) {
// Dedicated K-grouped (ragged-K) BF16-in / (FP32 or BF16)-out wgrad path:
// D_i = B_i.T @ A_i, K_i = routed-token dim. Shape eligibility is guarded above, so
// unsupported shapes fall back to cuBLAS rather than hard-erroring in the kernel.
cutlass_grouped_gemm_varlen_k(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
current_device, math_sm_count, stream);
Comment thread
cassiewilliam marked this conversation as resolved.
} else {
if (warn_fallback) {
NVTE_WARN("Fallback to cuBLAS grouped GEMM.");
Expand Down
102 changes: 102 additions & 0 deletions transformer_engine/common/gemm/cutlass_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
* See LICENSE for license information.
**************************************************************************************************/

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>

#include <cstdint>
#include <vector>

#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass_grouped_gemm.cuh"
Expand Down Expand Up @@ -36,6 +42,18 @@ template void CutlassGroupedGemm<false, true, cutlass::bfloat16_t>(const NVTETen
NVTETensor*, float, float, int,
cudaStream_t, int, int);

// Explicit instantiation: BF16-in / FP32-out (default) wgrad path.
template void CutlassGroupedGemmWgrad<true, false, float>(const NVTETensor*, const NVTETensor*,
NVTETensor*, NVTETensor*, float, float,
int, cudaStream_t, int, int);

// Explicit instantiation: BF16-in / BF16-out wgrad path.
template void CutlassGroupedGemmWgrad<true, false, cutlass::bfloat16_t>(const NVTETensor*,
const NVTETensor*,
NVTETensor*, NVTETensor*,
float, float, int,
cudaStream_t, int, int);

} // namespace grouped_gemm
} // namespace transformer_engine

Expand Down Expand Up @@ -75,3 +93,87 @@ void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor*
NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported.");
}
}

namespace {

// Zero-initialize empty (K=0) groups (when not accumulating) and forward the non-empty groups to
// CUTLASS. Precondition: the dispatcher (nvte_multi_tensor_gemm) has already validated the BF16 NT
// wgrad contract -- 2D, matching ragged K, uniform hidden/expert, BF16-in / (FP32|BF16)-out -- via
// is_bf16_wgrad_dtype() + is_bf16_wgrad_shape(), so it is not re-checked here.
void collect_bf16_wgrad_nt_groups(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
int num_gemms, bool accumulate, cudaStream_t stream,
std::vector<NVTETensor>* A_nz, std::vector<NVTETensor>* B_nz,
std::vector<NVTETensor>* D_nz,
transformer_engine::DType* out_dtype) {
using namespace transformer_engine;
// hidden/expert/output-dtype are uniform across groups; read them once from group 0.
const int64_t hidden = convertNVTETensorCheck(A[0])->data.shape[1];
const int64_t expert = convertNVTETensorCheck(B[0])->data.shape[1];
*out_dtype = convertNVTETensorCheck(D[0])->data.dtype;
const size_t elem = (*out_dtype == DType::kFloat32) ? sizeof(float) : sizeof(__nv_bfloat16);

for (int i = 0; i < num_gemms; ++i) {
if (convertNVTETensorCheck(A[i])->data.shape[0] == 0) {
// Empty group: its null A/B pointers would crash TMA descriptor construction, so zero the
// output (when not accumulating) and exclude it from the launch.
auto* out = convertNVTETensorCheck(D[i]);
if (!accumulate && out->data.dptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(out->data.dptr, 0,
static_cast<size_t>(expert) * hidden * elem, stream));
}
} else {
A_nz->push_back(A[i]);
B_nz->push_back(B[i]);
D_nz->push_back(D[i]);
}
}
}

} // namespace

void cutlass_grouped_gemm_varlen_k(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
int num_gemms, bool transa, bool transb, bool grad,
NVTETensor* workspace, bool accumulate, int device,
int math_sm_count, cudaStream_t stream) {
using namespace transformer_engine;
// The kernel hard-codes the NT layout, so assert it: a wrong-layout caller would otherwise
// mis-compute silently. (Arch / no-epilogue / group-0 dtype are already gated by the
// dispatcher, and a wrong arch would fail loudly inside the CUTLASS kernel anyway.)
NVTE_CHECK(!transa && transb && grad,
"cutlass_grouped_gemm_varlen_k requires NT wgrad layout "
"(transa=false, transb=true, grad=true).");
NVTE_CHECK(workspace != nullptr, "cutlass_grouped_gemm_varlen_k requires a non-null workspace.");

std::vector<NVTETensor> A_nz, B_nz, D_nz;
A_nz.reserve(num_gemms);
B_nz.reserve(num_gemms);
D_nz.reserve(num_gemms);
DType out_dtype = DType::kFloat32;
collect_bf16_wgrad_nt_groups(A, B, D, num_gemms, accumulate, stream, &A_nz, &B_nz, &D_nz,
&out_dtype);

// All groups have K=0: outputs are already zero-initialized above, nothing to launch.
if (A_nz.empty()) return;

const int n_nz = static_cast<int>(A_nz.size());
float one = 1.0;
float zero = 0.0;
float alpha = one;
float beta = (accumulate) ? one : zero;

// NT wgrad: D_i = B_i^T @ A_i. Pass grad_output (outer B) as CUTLASS A (trans_a=true)
// and input (outer A) as CUTLASS B (trans_b=false). CutlassGroupedGemmWgrad validates
// the workspace size internally.
auto dispatch = [&](auto tag) {
using T = decltype(tag);
grouped_gemm::CutlassGroupedGemmWgrad<true, false, T>(B_nz.data(), A_nz.data(), D_nz.data(),
workspace, alpha, beta, n_nz, stream,
device, math_sm_count);
};

if (out_dtype == DType::kFloat32) {
dispatch(float{});
} else {
dispatch(cutlass::bfloat16_t{});
}
}
Loading