Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
164f612
[PyTorch] Add fake implementations for Linear forward/backward
pggPL May 12, 2026
1a9176c
[PyTorch] Add torch.compile dynamo helper module
pggPL May 12, 2026
8bc7a1a
[PyTorch] Iterate on torch.compile support for Linear
pggPL May 13, 2026
f9e45ba
[PyTorch] Fix indentation errors in dynamo.py
pggPL May 14, 2026
be5c4ad
[PyTorch] Replace Linear fake-impls with output-info descriptors
pggPL May 27, 2026
1019b19
[PyTorch] Unify TensorSpec hierarchy and infer num_outputs dynamically
pggPL May 28, 2026
23399f3
[PyTorch] Require output-info descriptors for custom ops
pggPL May 28, 2026
89a80aa
[PyTorch] Drop dead parameters from torch.compile custom op registration
pggPL May 28, 2026
a3f6353
[PyTorch] Unify quantized tensor flatten via declarative schema
pggPL May 28, 2026
e4dfb9a
[PyTorch] Generic Quantizer.create_storage_metadata via declarative s…
pggPL May 28, 2026
dc875f0
[PyTorch] Generic Quantizer._flatten / _do_unflatten via declarative …
pggPL May 28, 2026
e25913c
[PyTorch] Drop dead code and stale comments from torch.compile branch
pggPL May 28, 2026
844f5c2
[PyTorch] Generic Quantizer.create_metadata via declarative schema
pggPL May 28, 2026
14381af
[PyTorch] Drop Recipe._flatten/_unflatten torch.compile protocol
pggPL May 28, 2026
edf11aa
[PyTorch] Consolidate tex.DType opaque-type registration in fp8_dtype.py
pggPL May 28, 2026
c1b0842
[PyTorch] Drop TensorSpec; reassemble torch.compile outputs from fake…
pggPL May 29, 2026
1a2a41a
[PyTorch] Remove dead Quantizer.create_storage_metadata
pggPL Jun 1, 2026
e262f4b
Merge remote-tracking branch 'upstream/main' into linear_torch_compil…
pggPL Jun 1, 2026
d3ba12c
[PyTorch] Add pythonic make_empty path for torch.compile fake-impl
pggPL Jun 1, 2026
98cd401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 139 additions & 13 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,10 @@ def __fx_repr__(self):
def _make_qfactory(tag: str):
"""Return a qfactory that produces ToyQuantizer instances tagged with *tag*."""

quantizers = {
role: ToyQuantizer(tag=f"{tag}:{role}")
for role in (
"linear_input",
"linear_weight",
"linear_output",
"linear_grad_output",
"linear_grad_input",
)
}

def qfactory(role: str):
return quantizers[role]
def qfactory(role):
# ``role`` is a QuantizerRole; tag each slot with its tensor_type so
# the produced ToyQuantizers are distinguishable per tensor.
return ToyQuantizer(tag=f"{tag}:{role.tensor_type}")

return qfactory

Expand Down Expand Up @@ -363,3 +354,138 @@ def fn(inp):

out = compiled(inp)
out.sum().backward()


@pytest.mark.parametrize(
"fp8_recipe",
[None, *_all_recipes],
ids=lambda r: "bf16" if r is None else type(r).__name__,
)
def test_te_linear_compiles(fp8_recipe):
"""torch.compile(fullgraph=True) of ``te.Linear`` under every built-in
recipe (and the bf16-only baseline with no autocast).

Exercises the custom-op path in
:mod:`transformer_engine.pytorch.dynamo`: forward goes through
``_linear_compiled_op``, backward through the registered
``transformer_engine::linear_backward`` op, and the dataclass
arg-objects are packed/unpacked via the bucket dispatch in
:mod:`transformer_engine.pytorch.dynamo`.
"""
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)

dtype = torch.bfloat16
device = "cuda"

# FP8 GEMMs require leading dimensions divisible by 16; pick
# in/out features and batch comfortably above that minimum.
model = te.Linear(64, 32, params_dtype=dtype, device=device)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
if fp8_recipe is None:
return model(inp)
with te.autocast(recipe=fp8_recipe):
return model(inp)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
out.sum().backward()
assert out.shape == (32, 32)
assert inp.grad is not None
assert model.weight.grad is not None, "weight.grad missing"
assert model.weight.grad.shape == model.weight.shape, (
f"weight.grad shape {tuple(model.weight.grad.shape)} != "
f"weight shape {tuple(model.weight.shape)}"
)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_te_linear_compile_with_quantized_fp8_weight():
"""torch.compile should handle Linear weights initialized as FP8 tensors."""
dtype = torch.bfloat16
device = "cuda"
fp8_recipe = recipe.Float8CurrentScaling()

with te.quantized_model_init(enabled=True, recipe=fp8_recipe):
model = te.Linear(64, 32, params_dtype=dtype, device=device)

assert isinstance(model.weight, te.Float8Tensor)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
with te.autocast(recipe=fp8_recipe):
return model(inp)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
out.sum().backward()
assert out.shape == (32, 32)
assert inp.grad is not None
assert model.weight.grad is not None, "Float8Tensor weight.grad missing"
assert model.weight.grad.shape == model.weight.shape, (
f"Float8Tensor weight.grad shape {tuple(model.weight.grad.shape)} != "
f"weight shape {tuple(model.weight.shape)}"
)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_te_linear_compile_with_fp8_output():
"""torch.compile of ``te.Linear(..., fp8_output=True)``: forward returns
a :class:`Float8Tensor`.

Exercises the output-rewrap path in
:mod:`transformer_engine.pytorch.dynamo`: the first user output is
declared ``Union[torch.Tensor, Float8Tensor]`` in ``output_annotations``,
and when an output quantizer is active the eager + fake paths must
rewrap the inner data tensors back into a ``Float8Tensor`` for the
user-facing slot.

Backward through a subclass return value is a known PyTorch
``torch.compile`` limitation (Dynamo / AOT autograd drop the
``grad_fn`` on wrapper-subclass outputs of custom ops, so
``out.sum().backward()`` errors with "element 0 of tensors does
not require grad and does not have a grad_fn"). The forward shape
+ type assertions below are sufficient to exercise the rewrap;
grad-routing on FP8 outputs under compile is left as future work.
"""
dtype = torch.bfloat16
device = "cuda"
fp8_recipe = recipe.Float8CurrentScaling()

model = te.Linear(64, 32, params_dtype=dtype, device=device)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
with te.autocast(recipe=fp8_recipe):
return model(inp, fp8_output=True)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
assert isinstance(
out, te.Float8Tensor
), f"expected Float8Tensor output, got {type(out).__name__}"
assert out.shape == (32, 32)
# The compile-path reassembly rebuilds the wrapper via
# ``__tensor_unflatten__``, whose snapshot-free ``meta`` forces
# ``quantizer=None`` (a live ``ProcessGroup`` / amax-reduction group
# can't survive Dynamo guards). ``make_fake_empty`` stashes the live
# quantizer on the fake template and the reassembly helper restores it,
# so the output must keep a (non-``None``) quantizer rather than losing
# its amax-reduction group.
assert (
out._quantizer is not None
), "FP8 output lost its quantizer (and thus its amax-reduction group) on the torch.compile path"
# Dequantising outside the compiled region exercises the
# ``Float8Tensor`` machinery (scale + data + dtype all wired up
# by the rewrap) on the value returned from the compiled fn.
deq = out.dequantize()
assert deq.shape == (32, 32)
assert deq.dtype == dtype
28 changes: 28 additions & 0 deletions transformer_engine/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@
tex.DType.kBFloat16: torch.bfloat16,
}

# Map: TE DType *id* (Python int) -> TE DType enum. Used by
# :func:`canonicalize_te_dtype` to recover the pybind enum from its
# integer id without going through ``tex.DType(int)``, which Dynamo
# cannot trace (pybind11 enum constructor is opaque).
TE_DType_ID_To_TE = {
int(tex.DType.kByte): tex.DType.kByte,
int(tex.DType.kFloat8E4M3): tex.DType.kFloat8E4M3,
int(tex.DType.kFloat8E5M2): tex.DType.kFloat8E5M2,
int(tex.DType.kFloat4E2M1): tex.DType.kFloat4E2M1,
int(tex.DType.kInt32): tex.DType.kInt32,
int(tex.DType.kFloat32): tex.DType.kFloat32,
int(tex.DType.kFloat16): tex.DType.kFloat16,
int(tex.DType.kBFloat16): tex.DType.kBFloat16,
}


def canonicalize_te_dtype(dtype):
"""Accept either a TE ``DType`` enum or its Python ``int`` id.

Recipe state keeps dtype ids as Python ``int`` values for cheap,
trace-friendly comparisons. Quantizer objects, however, are passed to
TE's C++ bindings, which expect the pybind ``tex.DType`` enum.
"""
if isinstance(dtype, int):
return TE_DType_ID_To_TE[dtype]
return dtype


# Cache enum -> int conversions to avoid repeated PyObject lookups.
FP8FwdTensorIdx = SimpleNamespace(
GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT),
Expand Down
Loading
Loading