Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/pytorch/test_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from transformer_engine.pytorch import (
DotProductAttention,
GroupedLinear,
LayerNormLinear,
LayerNormMLP,
Linear,
Expand Down Expand Up @@ -216,6 +217,38 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
return x


class _GroupedLinearWrapper(torch.nn.Module):
"""Adapt GroupedLinear to the [seqlen, batch, hidden] data used in this test.

GroupedLinear expects a 2D `[total_tokens, hidden]` input plus an `m_splits`
list, so this wrapper flattens the leading dims, splits the tokens evenly
across the GEMMs, and restores the original shape. It also forwards
`is_first_microbatch` so the FP8 weight-caching path is exercised under CUDA
graphs.
"""

def __init__(self, hidden_size: int, num_gemms: int, params_dtype: torch.dtype) -> None:
super().__init__()
self.num_gemms = num_gemms
self.grouped_linear = GroupedLinear(
num_gemms,
hidden_size,
hidden_size,
device="cuda",
params_dtype=params_dtype,
)

def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
seqlen, batch, hidden = input_.shape
x = input_.reshape(seqlen * batch, hidden)
total_tokens = x.shape[0]
assert total_tokens % self.num_gemms == 0, "tokens must split evenly across GEMMs"
split = total_tokens // self.num_gemms
m_splits = [split] * self.num_gemms
out = self.grouped_linear(x, m_splits, **kwargs)
return out.reshape(seqlen, batch, hidden)


# Supported modules
_test_cuda_graphs_modules: List[str] = [
# Put linear first to test the case where the cuda context might not be set in
Expand Down Expand Up @@ -315,6 +348,15 @@ def _test_cuda_graphs(
)
for _ in range(num_layers)
]
elif module == "grouped_linear":
modules = [
_GroupedLinearWrapper(
model_config.hidden_size,
num_gemms=2,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "linear_op":
modules = [
te_ops.Sequential(
Expand Down Expand Up @@ -501,6 +543,52 @@ def test_make_graphed_callables_with_fp8_weight_caching(
)


# Per-tensor FP8 recipes that support GroupedLinear FP8 weight caching.
_grouped_linear_fp8_weight_caching_recipes = []
if fp8_available:
_grouped_linear_fp8_weight_caching_recipes.append(recipe.DelayedScaling())
_grouped_linear_fp8_weight_caching_recipes.append(recipe.Float8CurrentScaling())


@pytest.mark.skipif(not fp8_available, reason="FP8 is not supported")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", _grouped_linear_fp8_weight_caching_recipes, ids=recipe_id)
def test_make_graphed_callables_grouped_linear_with_fp8_weight_caching(
*,
dtype: torch.dtype,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
model_config: str = "small",
num_layers: int = 3,
) -> None:
"""GroupedLinear must thread `is_first_microbatch` into the FP8 weight-update
skip tensor under CUDA graphs.

With `fp8_weight_caching` enabled, the graphed and non-graphed runs only match
when `skip_fp8_weight_update` is propagated for every microbatch. Before the
fix, GroupedLinear hardcoded it to `None`, so the cached FP8 weights diverged
from the eager reference. This regresses if that propagation is dropped again.
"""
config = model_configs[model_config]
kwargs = dict(
module="grouped_linear",
model_config=config,
num_layers=num_layers,
dtype=dtype,
fp8=True,
fp8_params=fp8_params,
fp8_weight_caching=True,
fp8_recipe=fp8_recipe,
)
graph_outputs_full = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_individual = _test_cuda_graphs(graph_mode="individual", **kwargs)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)

assert_all_equal(outputs, graph_outputs_full)
assert_all_equal(outputs, graph_outputs_individual)


def generate_data_for_dot_product_attention(
model_config: ModelConfig,
dtype: torch.dtype,
Expand Down
11 changes: 10 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,15 @@ def forward(
is_grad_enabled = torch.is_grad_enabled()
num_gemms = self.num_gemms

if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = (
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
)
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

# Make sure splits are in expected format
if not isinstance(m_splits, torch.Tensor):
# Convert list of ints to tensor for backward compatibility
Expand Down Expand Up @@ -1754,7 +1763,7 @@ def forward(
is_grad_enabled,
weight_workspaces,
cache_weight,
None, # skip_fp8_weight_update
skip_fp8_weight_update,
self.save_original_input,
debug,
)
Expand Down