-
Notifications
You must be signed in to change notification settings - Fork 741
[JAX] Hopper BF16 grouped GEMM v2 support #3083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1519,8 +1519,8 @@ def abstract( | |
| additional_args: Either | ||
| * group_offsets: 1D array containing offsets for each group (not yet implemented) | ||
| OR | ||
| * alpha: 1D array of shape (G,) containing alpha values for each group | ||
| * beta: 1D array of shape (G,) containing beta values for each group | ||
| * alpha: 1D array of shape (G,) or (1,) containing alpha values | ||
| * beta: 1D array of shape (G,) or (1,) containing beta values | ||
| lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed | ||
| rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed | ||
| scaling_mode: Scaling mode for the GEMM operations | ||
|
|
@@ -1606,12 +1606,17 @@ def abstract( | |
| f" GEMM primitive, but got {len(additional_args)} arguments." | ||
| ) | ||
| alpha_aval, beta_aval = additional_args | ||
| if alpha_aval.shape != (num_groups,): | ||
| raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}") | ||
| valid_alpha_beta_shapes = ((num_groups,), (1,)) | ||
| if alpha_aval.shape not in valid_alpha_beta_shapes: | ||
| raise ValueError( | ||
| f"Expected alpha shape {(num_groups,)} or (1,), got {alpha_aval.shape}" | ||
| ) | ||
| if alpha_aval.dtype != jnp.float32: | ||
| raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}") | ||
| if beta_aval.shape != (num_groups,): | ||
| raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}") | ||
| if beta_aval.shape not in valid_alpha_beta_shapes: | ||
| raise ValueError( | ||
| f"Expected beta shape {(num_groups,)} or (1,), got {beta_aval.shape}" | ||
| ) | ||
| if beta_aval.dtype != jnp.float32: | ||
| raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}") | ||
|
|
||
|
|
@@ -2091,6 +2096,11 @@ def _should_enforce_v2_grouped_gemm() -> bool: | |
| ) from e | ||
|
|
||
|
|
||
| def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool: | ||
| """Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices.""" | ||
| return get_min_device_compute_capability() >= 100 | ||
|
|
||
|
|
||
| def _is_v2_grouped_gemm_supported( | ||
| scaling_mode: ScalingMode, | ||
| dtype: jnp.dtype, | ||
|
|
@@ -2111,24 +2121,31 @@ def _is_v2_grouped_gemm_supported( | |
| ), | ||
| ) | ||
|
|
||
| # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). | ||
| # Fall back to the v1 path on SM90 (Hopper) and older architectures. | ||
| if get_min_device_compute_capability() < 100: | ||
| # nvte_grouped_gemm (the v2 kernel) supports BF16 on SM90+ (Hopper or newer). | ||
| # MXFP8 remains gated to SM100+ below. | ||
| if get_min_device_compute_capability() < 90: | ||
| return ( | ||
| False, | ||
| ( | ||
| "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device" | ||
| "The TE V2 grouped GEMM requires SM90+ (Hopper or newer) but current min device" | ||
| f" compute capability is {get_min_device_compute_capability()}." | ||
| ), | ||
| ) | ||
|
|
||
| if has_bias: | ||
| return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel." | ||
|
|
||
| if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: | ||
| return True, "" | ||
|
|
||
| if scaling_mode == ScalingMode.MXFP8_1D_SCALING: | ||
| if get_min_device_compute_capability() < 100: | ||
| return ( | ||
| False, | ||
| ( | ||
| "The TE V2 grouped GEMM for MXFP8 requires SM100+ (Blackwell or newer) but" | ||
| " current min device compute capability is" | ||
| f" {get_min_device_compute_capability()}." | ||
| ), | ||
| ) | ||
|
|
||
| # V2 MXFP8 requires that the total first dimension of both operands (up to | ||
| # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. | ||
| # Individual group sizes must also be 128-aligned (dynamic constraint). | ||
|
|
@@ -2188,9 +2205,10 @@ def _is_v2_grouped_gemm_supported( | |
| return ( | ||
| False, | ||
| ( | ||
| "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D" | ||
| " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input" | ||
| f" parameters do not meet these requirements (scaling_mode= {scaling_mode}," | ||
| "The TE V2 grouped GEMM currently only supports non-quantized BF16, and MXFP8 with" | ||
| " 1D block scaling on SM100+, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and" | ||
| " the input parameters do not meet these requirements" | ||
| f" (scaling_mode= {scaling_mode}," | ||
| f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape}," | ||
| f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})." | ||
| ), | ||
|
|
@@ -2390,6 +2408,35 @@ def _get_num_gemms( | |
| ) | ||
|
|
||
|
|
||
| def _add_grouped_gemm_bias( | ||
| out: jnp.ndarray, | ||
| bias: jnp.ndarray, | ||
| out_first_dims: Optional[jnp.ndarray], | ||
| out_last_dims: Optional[jnp.ndarray], | ||
| out_shape: Tuple[int, ...], | ||
| num_gemms: int, | ||
| n_dim: int, | ||
| ) -> jnp.ndarray: | ||
| """Add grouped GEMM bias in JAX for V2 kernels that do not fuse bias.""" | ||
| if out_last_dims is not None: | ||
| raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims") | ||
|
Comment on lines
+2421
to
+2422
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When |
||
|
|
||
| bias = bias.astype(out.dtype) | ||
| bias_2d = bias.reshape((num_gemms, n_dim)) | ||
| if out_first_dims is not None: | ||
| out_2d = out.reshape((-1, n_dim)) | ||
| bias_rows = jnp.repeat( | ||
| bias_2d, | ||
| out_first_dims, | ||
| axis=0, | ||
| total_repeat_length=out_2d.shape[0], | ||
| ) | ||
| return (out_2d + bias_rows).reshape(out_shape) | ||
|
|
||
| bias_shape = (num_gemms,) + (1,) * (out.ndim - 2) + (n_dim,) | ||
| return out + bias_2d.reshape(bias_shape) | ||
|
|
||
|
|
||
| def grouped_gemm( | ||
| lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], | ||
| rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], | ||
|
|
@@ -2509,7 +2556,8 @@ def grouped_gemm( | |
| num_gemms, | ||
| N_dim, | ||
| ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" | ||
| bias = jnp.empty((), jnp.float32) if bias is None else bias | ||
| else: | ||
| N_dim = 0 | ||
|
|
||
| if group_offset is not None: | ||
| raise RuntimeError( | ||
|
|
@@ -2538,18 +2586,21 @@ def grouped_gemm( | |
| raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling") | ||
|
|
||
| if use_v2_ffi: | ||
| additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha | ||
| additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta | ||
| alpha_beta_numel = num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1 | ||
| additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha | ||
| additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta | ||
| else: | ||
| additional_arg_0 = jnp.zeros((1,), jnp.int32) # group_offset | ||
| additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder | ||
| bias_for_ffi = jnp.empty((), jnp.float32) if (bias is None or use_v2_ffi) else bias | ||
| has_bias_for_ffi = has_bias and not use_v2_ffi | ||
|
|
||
| (out,) = GroupedGemmPrimitive.outer_primitive.bind( | ||
| lhs.data, | ||
| lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), | ||
| rhs.data, | ||
| rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32), | ||
| bias, | ||
| bias_for_ffi, | ||
| lhs.first_dims if lhs.first_dims is not None else empty_gs, | ||
| lhs.last_dims if lhs.last_dims is not None else empty_gs, | ||
| rhs.first_dims if rhs.first_dims is not None else empty_gs, | ||
|
|
@@ -2562,7 +2613,7 @@ def grouped_gemm( | |
| rhs_is_trans=rhs_is_trans, | ||
| scaling_mode=scaling_mode.value, | ||
| out_dtype=out_dtype, | ||
| has_bias=has_bias, | ||
| has_bias=has_bias_for_ffi, | ||
| use_async_d2h_group_sizes=use_async_d2h_group_sizes, | ||
| use_v2_ffi=use_v2_ffi, | ||
| lhs_axis_boundary=lhs_axis_boundary, | ||
|
|
@@ -2573,4 +2624,8 @@ def grouped_gemm( | |
| rhs_left_size=int(rhs_left_size), | ||
| rhs_right_size=int(rhs_right_size), | ||
| ) | ||
| if use_v2_ffi and has_bias: | ||
| out = _add_grouped_gemm_bias( | ||
| out, bias, out_first_dims, out_last_dims, out_shape, num_gemms, N_dim | ||
| ) | ||
| return out | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_v2_grouped_gemm_supports_per_group_alpha_beta()` calls `get_min_device_compute_capability()` on every `grouped_gemm` invocation, but unlike `_should_enforce_v2_grouped_gemm()` (which is decorated with `@cache`) it has no memoization. Both functions encode a process-wide constant; querying CUDA device capability in a hot path can add unnecessary overhead. Adding `@cache` (or `@functools.lru_cache(maxsize=None)`) mirrors the pattern already used by the sibling helper.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!