diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 212c8083ec..318fffff15 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1519,8 +1519,8 @@ def abstract( additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR - * alpha: 1D array of shape (G,) containing alpha values for each group - * beta: 1D array of shape (G,) containing beta values for each group + * alpha: 1D array of shape (G,) or (1,) containing alpha values + * beta: 1D array of shape (G,) or (1,) containing beta values lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations @@ -1606,12 +1606,17 @@ def abstract( f" GEMM primitive, but got {len(additional_args)} arguments." ) alpha_aval, beta_aval = additional_args - if alpha_aval.shape != (num_groups,): - raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}") + valid_alpha_beta_shapes = ((num_groups,), (1,)) + if alpha_aval.shape not in valid_alpha_beta_shapes: + raise ValueError( + f"Expected alpha shape {(num_groups,)} or (1,), got {alpha_aval.shape}" + ) if alpha_aval.dtype != jnp.float32: raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}") - if beta_aval.shape != (num_groups,): - raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}") + if beta_aval.shape not in valid_alpha_beta_shapes: + raise ValueError( + f"Expected beta shape {(num_groups,)} or (1,), got {beta_aval.shape}" + ) if beta_aval.dtype != jnp.float32: raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}") @@ -2091,6 +2096,11 @@ def _should_enforce_v2_grouped_gemm() -> bool: ) from e +def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool: + """Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices.""" + return get_min_device_compute_capability() >= 100 + + def _is_v2_grouped_gemm_supported( scaling_mode: ScalingMode, dtype: jnp.dtype, @@ -2111,24 +2121,31 @@ def _is_v2_grouped_gemm_supported( ), ) - # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). - # Fall back to the v1 path on SM90 (Hopper) and older architectures. - if get_min_device_compute_capability() < 100: + # nvte_grouped_gemm (the v2 kernel) supports BF16 on SM90+ (Hopper or newer). + # MXFP8 remains gated to SM100+ below. + if get_min_device_compute_capability() < 90: return ( False, ( - "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" + "The TE V2 grouped GEMM requires SM90+ (Hopper or newer) but current min device" f" compute capability is {get_min_device_compute_capability()}." ), ) - if has_bias: - return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel." - if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: return True, "" if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if get_min_device_compute_capability() < 100: + return ( + False, + ( + "The TE V2 grouped GEMM for MXFP8 requires SM100+ (Blackwell or newer) but" + " current min device compute capability is" + f" {get_min_device_compute_capability()}." + ), + ) + # V2 MXFP8 requires that the total first dimension of both operands (up to # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. # Individual group sizes must also be 128-aligned (dynamic constraint). @@ -2188,9 +2205,10 @@ def _is_v2_grouped_gemm_supported( return ( False, ( - "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" - " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" - f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," + "The TE V2 grouped GEMM currently only supports non-quantized BF16, and MXFP8 with" + " 1D block scaling on SM100+, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and" + " the input parameters do not meet these requirements" + f" (scaling_mode= {scaling_mode}," f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." ), @@ -2390,6 +2408,35 @@ def _get_num_gemms( ) +def _add_grouped_gemm_bias( + out: jnp.ndarray, + bias: jnp.ndarray, + out_first_dims: Optional[jnp.ndarray], + out_last_dims: Optional[jnp.ndarray], + out_shape: Tuple[int, ...], + num_gemms: int, + n_dim: int, +) -> jnp.ndarray: + """Add grouped GEMM bias in JAX for V2 kernels that do not fuse bias.""" + if out_last_dims is not None: + raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims") + + bias = bias.astype(out.dtype) + bias_2d = bias.reshape((num_gemms, n_dim)) + if out_first_dims is not None: + out_2d = out.reshape((-1, n_dim)) + bias_rows = jnp.repeat( + bias_2d, + out_first_dims, + axis=0, + total_repeat_length=out_2d.shape[0], + ) + return (out_2d + bias_rows).reshape(out_shape) + + bias_shape = (num_gemms,) + (1,) * (out.ndim - 2) + (n_dim,) + return out + bias_2d.reshape(bias_shape) + + def grouped_gemm( lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], @@ -2509,7 +2556,8 @@ def grouped_gemm( num_gemms, N_dim, ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" - bias = jnp.empty((), jnp.float32) if bias is None else bias + else: + N_dim = 0 if group_offset is not None: raise RuntimeError( @@ -2538,18 +2586,21 @@ def grouped_gemm( raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling") if use_v2_ffi: - additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha - additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta + alpha_beta_numel = num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1 + additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha + additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta else: additional_arg_0 = jnp.zeros((1,), jnp.int32) # group_offset additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder + bias_for_ffi = jnp.empty((), jnp.float32) if (bias is None or use_v2_ffi) else bias + has_bias_for_ffi = has_bias and not use_v2_ffi (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs.data, lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), rhs.data, rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), - bias, + bias_for_ffi, lhs.first_dims if lhs.first_dims is not None else empty_gs, lhs.last_dims if lhs.last_dims is not None else empty_gs, rhs.first_dims if rhs.first_dims is not None else empty_gs, @@ -2562,7 +2613,7 @@ def grouped_gemm( rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, - has_bias=has_bias, + has_bias=has_bias_for_ffi, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, lhs_axis_boundary=lhs_axis_boundary, @@ -2573,4 +2624,8 @@ def grouped_gemm( rhs_left_size=int(rhs_left_size), rhs_right_size=int(rhs_right_size), ) + if use_v2_ffi and has_bias: + out = _add_grouped_gemm_bias( + out, bias, out_first_dims, out_last_dims, out_shape, num_gemms, N_dim + ) return out diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a36bf6cd22..65b6b1286e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -953,10 +953,10 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty DType::kByte); TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), - std::vector{num_gemms}, + std::vector{alpha.element_count()}, convert_ffi_datatype_to_te_dtype(alpha.element_type())); TensorWrapper beta_tensor(static_cast(beta.untyped_data()), - std::vector{num_gemms}, + std::vector{beta.element_count()}, convert_ffi_datatype_to_te_dtype(beta.element_type())); // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed.