10 changes: 3 additions & 7 deletions src/lightning/pytorch/CHANGELOG.md
@@ -6,14 +6,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

---

## [Unreleased] - YYYY-MM-DD

### Added

-

### Changed
### Fixed

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model, so saving optimizer states no longer raises a ``KeyError`` ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
-

### Deprecated
@@ -22,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

---
- Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))

### Fixed
54 changes: 53 additions & 1 deletion src/lightning/pytorch/strategies/model_parallel.py
@@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:

state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
if not self._save_distributed_checkpoint and self.global_rank == 0:
# Store the optimizer state dict in standard format
state_dict = _align_compiled_param_names_with_module(state_dict, self.model)
state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
return state_dict

@@ -366,3 +366,55 @@ def set_world_ranks(self) -> None:
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank


def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]:
"""Align optimizer state dict keys with a module that may have compiled submodules.

When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``.
For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer
state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod``
prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from
``module.named_parameters()``.

This function inserts ``._orig_mod`` into the state dict keys where necessary so that
they match the module's ``named_parameters()`` output.

"""
from torch._dynamo import OptimizedModule

# Collect prefixes of compiled submodules (e.g., "model" if that submodule was compiled)
compiled_prefixes: list[str] = []
for name, submodule in module.named_modules():
if isinstance(submodule, OptimizedModule):
compiled_prefixes.append(name)

if not compiled_prefixes:
return state_dict

# Sort by length descending so longer prefixes are matched first
compiled_prefixes.sort(key=len, reverse=True)

def _transform_key(key: str) -> str:
for prefix in compiled_prefixes:
# Match the compiled module path exactly or as a dotted prefix
if key == prefix or key.startswith(prefix + "."):
suffix = key[len(prefix) :] # e.g., ".0.weight" or ""
# Insert _orig_mod between prefix and rest
return f"{prefix}._orig_mod{suffix}"
return key

# Transform keys in "state" section of the optimizer state dict
if "state" in state_dict:
new_state = {_transform_key(k): v for k, v in state_dict["state"].items()}
state_dict = {**state_dict, "state": new_state}

# Transform param names in "param_groups" section
if "param_groups" in state_dict:
new_param_groups = []
for group in state_dict["param_groups"]:
new_group = {**group, "params": [_transform_key(p) for p in group["params"]]}
new_param_groups.append(new_group)
state_dict = {**state_dict, "param_groups": new_param_groups}

return state_dict
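
For context (not part of the diff): a minimal sketch of the naming shift this helper compensates for. Compiling a submodule wraps it in ``OptimizedModule``, which inserts ``_orig_mod`` into the fully qualified parameter names returned by ``named_parameters()``.

```python
# Minimal sketch: torch.compile on a submodule re-parents its parameters under `_orig_mod`.
import torch
from torch import nn

net = nn.Module()
net.model = nn.Sequential(nn.Linear(4, 4))
print([name for name, _ in net.named_parameters()])
# ['model.0.weight', 'model.0.bias']

net.model = torch.compile(net.model)  # wraps the nn.Sequential in an OptimizedModule
print([name for name, _ in net.named_parameters()])
# ['model._orig_mod.0.weight', 'model._orig_mod.0.bias']
```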
75 changes: 75 additions & 0 deletions tests/tests_pytorch/strategies/test_model_parallel.py
@@ -251,3 +251,78 @@ def configure_model(self) -> None:
strategy.setup(Mock())
assert all(not p.is_meta for p in model.parameters())
assert all(not b.is_meta for b in model.buffers())


@RunIf(min_torch="2.4")
def test_align_compiled_param_names_with_module():
"""Test that optimizer state dict keys are aligned with compiled submodule parameter names."""
from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module

class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

def forward(self, x):
return self.model(x)

# Test with compiled submodule
m = SimpleModule()
m.model = torch.compile(m.model)

# Simulate optimizer state dict without _orig_mod in keys (includes both state and param_groups)
state_dict = {
"state": {
"model.0.weight": {"step": 1},
"model.0.bias": {"step": 1},
"model.2.weight": {"step": 1},
"model.2.bias": {"step": 1},
},
"param_groups": [{"params": ["model.0.weight", "model.0.bias", "model.2.weight", "model.2.bias"], "lr": 0.01}],
}

result = _align_compiled_param_names_with_module(state_dict, m)

# Verify state keys now have _orig_mod inserted
expected_keys = {
"model._orig_mod.0.weight",
"model._orig_mod.0.bias",
"model._orig_mod.2.weight",
"model._orig_mod.2.bias",
}
assert set(result["state"].keys()) == expected_keys

# Verify param_groups params also have _orig_mod inserted
assert set(result["param_groups"][0]["params"]) == expected_keys

# Verify they match the module's named_parameters
param_names = {name for name, _ in m.named_parameters()}
assert set(result["state"].keys()) == param_names


@RunIf(min_torch="2.4")
def test_align_compiled_param_names_no_compile():
"""Test that non-compiled modules pass through unchanged."""
from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module

class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))

def forward(self, x):
return self.model(x)

m = SimpleModule() # Not compiled

state_dict = {
"state": {
"model.0.weight": {"step": 1},
"model.0.bias": {"step": 1},
}
}

result = _align_compiled_param_names_with_module(state_dict, m)

# Keys should be unchanged
assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"}
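
The tests above exercise the helper directly. For intuition on the original failure, here is a simplified illustration (a hypothetical sketch, not torch's actual ``rekey_optim_state_dict`` implementation): rekeying to parameter ids looks each name up in a mapping built from ``module.named_parameters()``, so FQNs lacking the ``_orig_mod`` prefix have no entry and raise ``KeyError``.

```python
# Hypothetical, simplified rekeying helper for illustration only; the real logic
# lives in torch.distributed.fsdp.FSDP.rekey_optim_state_dict.
from typing import Any

import torch
from torch import nn


def rekey_state_by_param_id(state: dict[str, Any], module: nn.Module) -> dict[int, Any]:
    # The name -> id mapping is derived from the (possibly compiled) module.
    name_to_id = {name: idx for idx, (name, _) in enumerate(module.named_parameters())}
    return {name_to_id[name]: value for name, value in state.items()}


net = nn.Module()
net.model = nn.Sequential(nn.Linear(4, 4))
net.model = torch.compile(net.model)

state = {"model.0.weight": {"step": 1}}  # FQN without the `_orig_mod` prefix
try:
    rekey_state_by_param_id(state, net)
except KeyError as err:
    print(f"no matching parameter: {err}")  # no matching parameter: 'model.0.weight'
```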
@@ -135,6 +135,33 @@ def configure_model(self):
parallelize(self.model, device_mesh=self.device_mesh)


class SimpleCompiledModule(LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
self._loss = nn.MSELoss()

def configure_model(self):
self.model = torch.compile(self.model)

def training_step(self, batch, batch_idx):
x, y = batch
preds = self.model(x)
return self._loss(preds, y)

def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-3)


def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2):
total_samples = batch_size * num_batches
generator = torch.Generator().manual_seed(0)
features = torch.randn(total_samples, 32, generator=generator)
targets = torch.randn(total_samples, 32, generator=generator)
dataset = torch.utils.data.TensorDataset(features, targets)
return DataLoader(dataset, batch_size=batch_size)


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_setup_device_mesh(distributed):
from torch.distributed.device_mesh import DeviceMesh
@@ -237,6 +264,44 @@ def training_step(self, batch):
trainer.fit(model)


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
"""Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing."""

seed_everything(0)
strategy = ModelParallelStrategy(
data_parallel_size=1,
tensor_parallel_size=1,
save_distributed_checkpoint=False,
)

trainer = Trainer(
accelerator="auto",
devices=1,
strategy=strategy,
max_steps=2,
limit_train_batches=2,
enable_checkpointing=False,
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
default_root_dir=tmp_path,
)

dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2)

with trainer.init_module(empty_init=True):
model = SimpleCompiledModule()

trainer.fit(model, dataloader)

if trainer.is_global_zero:
checkpoint_path = tmp_path / "compiled-model.ckpt"
trainer.save_checkpoint(checkpoint_path)

trainer.strategy.barrier()


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
@pytest.mark.parametrize(
"compile",