Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,8 +1519,8 @@ def abstract(
additional_args: Either
* group_offsets: 1D array containing offsets for each group (not yet implemented)
OR
* alpha: 1D array of shape (G,) containing alpha values for each group
* beta: 1D array of shape (G,) containing beta values for each group
* alpha: 1D array of shape (G,) or (1,) containing alpha values
* beta: 1D array of shape (G,) or (1,) containing beta values
lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
scaling_mode: Scaling mode for the GEMM operations
Expand Down Expand Up @@ -1606,12 +1606,17 @@ def abstract(
f" GEMM primitive, but got {len(additional_args)} arguments."
)
alpha_aval, beta_aval = additional_args
if alpha_aval.shape != (num_groups,):
raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}")
valid_alpha_beta_shapes = ((num_groups,), (1,))
if alpha_aval.shape not in valid_alpha_beta_shapes:
raise ValueError(
f"Expected alpha shape {(num_groups,)} or (1,), got {alpha_aval.shape}"
)
if alpha_aval.dtype != jnp.float32:
raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}")
if beta_aval.shape != (num_groups,):
raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}")
if beta_aval.shape not in valid_alpha_beta_shapes:
raise ValueError(
f"Expected beta shape {(num_groups,)} or (1,), got {beta_aval.shape}"
)
if beta_aval.dtype != jnp.float32:
raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}")

Expand Down Expand Up @@ -2091,6 +2096,11 @@ def _should_enforce_v2_grouped_gemm() -> bool:
) from e


def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool:
"""Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices."""
return get_min_device_compute_capability() >= 100
Comment on lines +2099 to +2101
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing cache on capability check

_v2_grouped_gemm_supports_per_group_alpha_beta() calls get_min_device_compute_capability() on every grouped_gemm invocation, but unlike _should_enforce_v2_grouped_gemm() (which is decorated with @cache) it has no memoization. Both functions encode a process-wide constant; querying CUDA device capability in a hot path can add unnecessary overhead. Adding @cache (or @functools.lru_cache(maxsize=None)) mirrors the pattern already used by the sibling helper.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



def _is_v2_grouped_gemm_supported(
scaling_mode: ScalingMode,
dtype: jnp.dtype,
Expand All @@ -2111,24 +2121,31 @@ def _is_v2_grouped_gemm_supported(
),
)

# nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer).
# Fall back to the v1 path on SM90 (Hopper) and older architectures.
if get_min_device_compute_capability() < 100:
# nvte_grouped_gemm (the v2 kernel) supports BF16 on SM90+ (Hopper or newer).
# MXFP8 remains gated to SM100+ below.
if get_min_device_compute_capability() < 90:
return (
False,
(
"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device"
"The TE V2 grouped GEMM requires SM90+ (Hopper or newer) but current min device"
f" compute capability is {get_min_device_compute_capability()}."
),
)

if has_bias:
return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel."

if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16:
return True, ""

if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
if get_min_device_compute_capability() < 100:
return (
False,
(
"The TE V2 grouped GEMM for MXFP8 requires SM100+ (Blackwell or newer) but"
" current min device compute capability is"
f" {get_min_device_compute_capability()}."
),
)

# V2 MXFP8 requires that the total first dimension of both operands (up to
# axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement.
# Individual group sizes must also be 128-aligned (dynamic constraint).
Expand Down Expand Up @@ -2188,9 +2205,10 @@ def _is_v2_grouped_gemm_supported(
return (
False,
(
"The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D"
" block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input"
f" parameters do not meet these requirements (scaling_mode= {scaling_mode},"
"The TE V2 grouped GEMM currently only supports non-quantized BF16, and MXFP8 with"
" 1D block scaling on SM100+, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and"
" the input parameters do not meet these requirements"
f" (scaling_mode= {scaling_mode},"
f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape},"
f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})."
),
Expand Down Expand Up @@ -2390,6 +2408,35 @@ def _get_num_gemms(
)


def _add_grouped_gemm_bias(
out: jnp.ndarray,
bias: jnp.ndarray,
out_first_dims: Optional[jnp.ndarray],
out_last_dims: Optional[jnp.ndarray],
out_shape: Tuple[int, ...],
num_gemms: int,
n_dim: int,
) -> jnp.ndarray:
"""Add grouped GEMM bias in JAX for V2 kernels that do not fuse bias."""
if out_last_dims is not None:
raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims")
Comment on lines +2421 to +2422
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Runtime error surfaces after kernel execution for bias + ragged-last-dims

When has_bias=True and out_last_dims is not None, _is_v2_grouped_gemm_supported still returns True for BF16 on SM90+ (no gate for this combination), so the full V2 kernel is dispatched before _add_grouped_gemm_bias raises NotImplementedError. The check should be moved upstream — either into _is_v2_grouped_gemm_supported (returning False to fall back to V1) or as an early guard in grouped_gemm before the FFI bind.


bias = bias.astype(out.dtype)
bias_2d = bias.reshape((num_gemms, n_dim))
if out_first_dims is not None:
out_2d = out.reshape((-1, n_dim))
bias_rows = jnp.repeat(
bias_2d,
out_first_dims,
axis=0,
total_repeat_length=out_2d.shape[0],
)
return (out_2d + bias_rows).reshape(out_shape)

bias_shape = (num_gemms,) + (1,) * (out.ndim - 2) + (n_dim,)
return out + bias_2d.reshape(bias_shape)


def grouped_gemm(
lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x],
rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x],
Expand Down Expand Up @@ -2509,7 +2556,8 @@ def grouped_gemm(
num_gemms,
N_dim,
), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
bias = jnp.empty((), jnp.float32) if bias is None else bias
else:
N_dim = 0

if group_offset is not None:
raise RuntimeError(
Expand Down Expand Up @@ -2538,18 +2586,21 @@ def grouped_gemm(
raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling")

if use_v2_ffi:
additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha
additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta
alpha_beta_numel = num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha
additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta
else:
additional_arg_0 = jnp.zeros((1,), jnp.int32) # group_offset
additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder
bias_for_ffi = jnp.empty((), jnp.float32) if (bias is None or use_v2_ffi) else bias
has_bias_for_ffi = has_bias and not use_v2_ffi

(out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs.data,
lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32),
rhs.data,
rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32),
bias,
bias_for_ffi,
lhs.first_dims if lhs.first_dims is not None else empty_gs,
lhs.last_dims if lhs.last_dims is not None else empty_gs,
rhs.first_dims if rhs.first_dims is not None else empty_gs,
Expand All @@ -2562,7 +2613,7 @@ def grouped_gemm(
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
has_bias=has_bias,
has_bias=has_bias_for_ffi,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
use_v2_ffi=use_v2_ffi,
lhs_axis_boundary=lhs_axis_boundary,
Expand All @@ -2573,4 +2624,8 @@ def grouped_gemm(
rhs_left_size=int(rhs_left_size),
rhs_right_size=int(rhs_right_size),
)
if use_v2_ffi and has_bias:
out = _add_grouped_gemm_bias(
out, bias, out_first_dims, out_last_dims, out_shape, num_gemms, N_dim
)
return out
4 changes: 2 additions & 2 deletions transformer_engine/jax/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,10 +953,10 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty
DType::kByte);

TensorWrapper alpha_tensor(static_cast<void *>(alpha.untyped_data()),
std::vector<size_t>{num_gemms},
std::vector<size_t>{alpha.element_count()},
convert_ffi_datatype_to_te_dtype(alpha.element_type()));
TensorWrapper beta_tensor(static_cast<void *>(beta.untyped_data()),
std::vector<size_t>{num_gemms},
std::vector<size_t>{beta.element_count()},
convert_ffi_datatype_to_te_dtype(beta.element_type()));

// Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed.
Expand Down
Loading