diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index bb4a4e3857..f07d8b4d9b 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -9,6 +9,7 @@
 import torch
 from transformer_engine.pytorch import (
     DotProductAttention,
+    GroupedLinear,
     LayerNormLinear,
     LayerNormMLP,
     Linear,
@@ -216,6 +217,38 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
         return x
 
 
+class _GroupedLinearWrapper(torch.nn.Module):
+    """Adapt GroupedLinear to the [seqlen, batch, hidden] data used in this test.
+
+    GroupedLinear expects a 2D `[total_tokens, hidden]` input plus an `m_splits`
+    list, so this wrapper flattens the leading dims, splits the tokens evenly
+    across the GEMMs, and restores the original shape. It also forwards
+    `is_first_microbatch` so the FP8 weight-caching path is exercised under CUDA
+    graphs.
+    """
+
+    def __init__(self, hidden_size: int, num_gemms: int, params_dtype: torch.dtype) -> None:
+        super().__init__()
+        self.num_gemms = num_gemms
+        self.grouped_linear = GroupedLinear(
+            num_gemms,
+            hidden_size,
+            hidden_size,
+            device="cuda",
+            params_dtype=params_dtype,
+        )
+
+    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
+        seqlen, batch, hidden = input_.shape
+        x = input_.reshape(seqlen * batch, hidden)
+        total_tokens = x.shape[0]
+        assert total_tokens % self.num_gemms == 0, "tokens must split evenly across GEMMs"
+        split = total_tokens // self.num_gemms
+        m_splits = [split] * self.num_gemms
+        out = self.grouped_linear(x, m_splits, **kwargs)
+        return out.reshape(seqlen, batch, hidden)
+
+
 # Supported modules
 _test_cuda_graphs_modules: List[str] = [
     # Put linear first to test the case where the cuda context might not be set in
@@ -315,6 +348,15 @@ def _test_cuda_graphs(
                 )
                 for _ in range(num_layers)
             ]
+        elif module == "grouped_linear":
+            modules = [
+                _GroupedLinearWrapper(
+                    model_config.hidden_size,
+                    num_gemms=2,
+                    params_dtype=dtype,
+                )
+                for _ in range(num_layers)
+            ]
         elif module == "linear_op":
             modules = [
                 te_ops.Sequential(
@@ -501,6 +543,52 @@ def test_make_graphed_callables_with_fp8_weight_caching(
     )
 
 
+# Per-tensor FP8 recipes that support GroupedLinear FP8 weight caching.
+_grouped_linear_fp8_weight_caching_recipes = []
+if fp8_available:
+    _grouped_linear_fp8_weight_caching_recipes.append(recipe.DelayedScaling())
+    _grouped_linear_fp8_weight_caching_recipes.append(recipe.Float8CurrentScaling())
+
+
+@pytest.mark.skipif(not fp8_available, reason="FP8 is not supported")
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("fp8_params", (False, True))
+@pytest.mark.parametrize("fp8_recipe", _grouped_linear_fp8_weight_caching_recipes, ids=recipe_id)
+def test_make_graphed_callables_grouped_linear_with_fp8_weight_caching(
+    *,
+    dtype: torch.dtype,
+    fp8_params: bool,
+    fp8_recipe: recipe.Recipe,
+    model_config: str = "small",
+    num_layers: int = 3,
+) -> None:
+    """GroupedLinear must thread `is_first_microbatch` into the FP8 weight-update
+    skip tensor under CUDA graphs.
+
+    With `fp8_weight_caching` enabled, the graphed and non-graphed runs only match
+    when `skip_fp8_weight_update` is propagated for every microbatch. Before the
+    fix, GroupedLinear hardcoded it to `None`, so the cached FP8 weights diverged
+    from the eager reference. This regresses if that propagation is dropped again.
+    """
+    config = model_configs[model_config]
+    kwargs = dict(
+        module="grouped_linear",
+        model_config=config,
+        num_layers=num_layers,
+        dtype=dtype,
+        fp8=True,
+        fp8_params=fp8_params,
+        fp8_weight_caching=True,
+        fp8_recipe=fp8_recipe,
+    )
+    graph_outputs_full = _test_cuda_graphs(graph_mode="full", **kwargs)
+    graph_outputs_individual = _test_cuda_graphs(graph_mode="individual", **kwargs)
+    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
+
+    assert_all_equal(outputs, graph_outputs_full)
+    assert_all_equal(outputs, graph_outputs_individual)
+
+
 def generate_data_for_dot_product_attention(
     model_config: ModelConfig,
     dtype: torch.dtype,
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 15ec3fe322..62007641e4 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -1684,6 +1684,15 @@ def forward(
         is_grad_enabled = torch.is_grad_enabled()
         num_gemms = self.num_gemms
 
+        if FP8GlobalStateManager.fp8_graph_capturing():
+            skip_fp8_weight_update = (
+                FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
+            )
+        else:
+            skip_fp8_weight_update = None
+        if skip_fp8_weight_update is not None:
+            is_first_microbatch = False
+
         # Make sure splits are in expected format
         if not isinstance(m_splits, torch.Tensor):
             # Convert list of ints to tensor for backward compatibility
@@ -1754,7 +1763,7 @@ def forward(
                 is_grad_enabled,
                 weight_workspaces,
                 cache_weight,
-                None,  # skip_fp8_weight_update
+                skip_fp8_weight_update,
                 self.save_original_input,
                 debug,
             )