Commit 28f18dc

Fix: ModelParallelStrategy fails with non-distributed checkpoint (#21384)

* Add regression test for ModelParallel single-file checkpoint
* Fix ModelParallel single-file checkpoint with compiled modules

Parent: 9e1b038

File tree: 4 files changed, +196 / -8 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 7 deletions
@@ -6,14 +6,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
-## [Unreleased] - YYYY-MM-DD
-
-### Added
-
--
-
-### Changed
+### Fixed
 
+- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
 -
 
 ### Deprecated
@@ -22,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+---
 - Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))
 
 ### Fixed
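For context, the setup this entry refers to looks roughly like the sketch below. It is illustrative only: the module, loss, and dataloader are placeholders, and `ModelParallelStrategy` needs a CUDA/distributed environment just like the tests further down. The actual regression test is in `tests/tests_pytorch/strategies/test_model_parallel_integration.py` in this commit.

```python
import torch
from torch import nn
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import ModelParallelStrategy


class CompiledModule(LightningModule):  # placeholder module, mirrors the regression test below
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

    def configure_model(self):
        # torch.compile wraps self.model in an OptimizedModule, renaming its parameters
        self.model = torch.compile(self.model)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)


# save_distributed_checkpoint=False requests a single consolidated checkpoint file
strategy = ModelParallelStrategy(save_distributed_checkpoint=False)
trainer = Trainer(strategy=strategy, accelerator="auto", devices=1, max_steps=2)
# trainer.fit(CompiledModule(), train_dataloaders=...)  # requires a CUDA/distributed setup
# trainer.save_checkpoint("model.ckpt")  # raised KeyError for optimizer states before this fix
```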

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 53 additions & 1 deletion
@@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:
 
         state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
         if not self._save_distributed_checkpoint and self.global_rank == 0:
-            # Store the optimizer state dict in standard format
+            state_dict = _align_compiled_param_names_with_module(state_dict, self.model)
             state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
         return state_dict
 
@@ -366,3 +366,55 @@ def set_world_ranks(self) -> None:
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
         rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
+
+
+def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]:
+    """Align optimizer state dict keys with a module that may have compiled submodules.
+
+    When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``.
+    For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer
+    state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod``
+    prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from
+    ``module.named_parameters()``.
+
+    This function inserts ``._orig_mod`` into the state dict keys where necessary so that
+    they match the module's ``named_parameters()`` output.
+
+    """
+    from torch._dynamo import OptimizedModule
+
+    # Build set of compiled submodule prefixes (e.g., "model" if model is compiled)
+    compiled_prefixes: list[str] = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, OptimizedModule):
+            compiled_prefixes.append(name)
+
+    if not compiled_prefixes:
+        return state_dict
+
+    # Sort by length descending so longer prefixes are matched first
+    compiled_prefixes.sort(key=len, reverse=True)
+
+    def _transform_key(key: str) -> str:
+        for prefix in compiled_prefixes:
+            # Check if key starts with "prefix." (the compiled module path)
+            if key == prefix or key.startswith(prefix + "."):
+                suffix = key[len(prefix) :]  # e.g., ".0.weight" or ""
+                # Insert _orig_mod between prefix and rest
+                return f"{prefix}._orig_mod{suffix}"
+        return key
+
+    # Transform keys in "state" section of the optimizer state dict
+    if "state" in state_dict:
+        new_state = {_transform_key(k): v for k, v in state_dict["state"].items()}
+        state_dict = {**state_dict, "state": new_state}
+
+    # Transform param names in "param_groups" section
+    if "param_groups" in state_dict:
+        new_param_groups = []
+        for group in state_dict["param_groups"]:
+            new_group = {**group, "params": [_transform_key(p) for p in group["params"]]}
+            new_param_groups.append(new_group)
+        state_dict = {**state_dict, "param_groups": new_param_groups}
+
+    return state_dict
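The `_orig_mod` renaming described in the docstring above can be reproduced with plain `torch`; a minimal sketch follows (the `Wrapper` container is illustrative, mirroring the unit tests below).

```python
import torch
from torch import nn


class Wrapper(nn.Module):  # illustrative container with a compilable submodule
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


m = Wrapper()
print([name for name, _ in m.named_parameters()][:2])
# ['model.0.weight', 'model.0.bias']

m.model = torch.compile(m.model)  # m.model is now a torch._dynamo OptimizedModule
print([name for name, _ in m.named_parameters()][:2])
# ['model._orig_mod.0.weight', 'model._orig_mod.0.bias']
```

`_align_compiled_param_names_with_module` rewrites optimizer state-dict keys from the first form into the second, so that `FSDP.rekey_optim_state_dict`, which builds its mapping from `module.named_parameters()`, can resolve every parameter.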

tests/tests_pytorch/strategies/test_model_parallel.py

Lines changed: 75 additions & 0 deletions
@@ -251,3 +251,78 @@ def configure_model(self) -> None:
     strategy.setup(Mock())
     assert all(not p.is_meta for p in model.parameters())
     assert all(not b.is_meta for b in model.buffers())
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_with_module():
+    """Test that optimizer state dict keys are aligned with compiled submodule parameter names."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    # Test with compiled submodule
+    m = SimpleModule()
+    m.model = torch.compile(m.model)
+
+    # Simulate optimizer state dict without _orig_mod in keys (includes both state and param_groups)
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+            "model.2.weight": {"step": 1},
+            "model.2.bias": {"step": 1},
+        },
+        "param_groups": [{"params": ["model.0.weight", "model.0.bias", "model.2.weight", "model.2.bias"], "lr": 0.01}],
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Verify state keys now have _orig_mod inserted
+    expected_keys = {
+        "model._orig_mod.0.weight",
+        "model._orig_mod.0.bias",
+        "model._orig_mod.2.weight",
+        "model._orig_mod.2.bias",
+    }
+    assert set(result["state"].keys()) == expected_keys
+
+    # Verify param_groups params also have _orig_mod inserted
+    assert set(result["param_groups"][0]["params"]) == expected_keys
+
+    # Verify they match the module's named_parameters
+    param_names = {name for name, _ in m.named_parameters()}
+    assert set(result["state"].keys()) == param_names
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_no_compile():
+    """Test that non-compiled modules pass through unchanged."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    m = SimpleModule()  # Not compiled
+
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+        }
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Keys should be unchanged
+    assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"}

tests/tests_pytorch/strategies/test_model_parallel_integration.py

Lines changed: 65 additions & 0 deletions
@@ -135,6 +135,33 @@ def configure_model(self):
         parallelize(self.model, device_mesh=self.device_mesh)
 
 
+class SimpleCompiledModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
+        self._loss = nn.MSELoss()
+
+    def configure_model(self):
+        self.model = torch.compile(self.model)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        preds = self.model(x)
+        return self._loss(preds, y)
+
+    def configure_optimizers(self):
+        return torch.optim.AdamW(self.parameters(), lr=1e-3)
+
+
+def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2):
+    total_samples = batch_size * num_batches
+    generator = torch.Generator().manual_seed(0)
+    features = torch.randn(total_samples, 32, generator=generator)
+    targets = torch.randn(total_samples, 32, generator=generator)
+    dataset = torch.utils.data.TensorDataset(features, targets)
+    return DataLoader(dataset, batch_size=batch_size)
+
+
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_setup_device_mesh(distributed):
     from torch.distributed.device_mesh import DeviceMesh
@@ -237,6 +264,44 @@ def training_step(self, batch):
     trainer.fit(model)
 
 
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
+def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
+    """Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing."""
+
+    seed_everything(0)
+    strategy = ModelParallelStrategy(
+        data_parallel_size=1,
+        tensor_parallel_size=1,
+        save_distributed_checkpoint=False,
+    )
+
+    trainer = Trainer(
+        accelerator="auto",
+        devices=1,
+        strategy=strategy,
+        max_steps=2,
+        limit_train_batches=2,
+        enable_checkpointing=False,
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        default_root_dir=tmp_path,
+    )
+
+    dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2)
+
+    with trainer.init_module(empty_init=True):
+        model = SimpleCompiledModule()
+
+    trainer.fit(model, dataloader)
+
+    if trainer.is_global_zero:
+        checkpoint_path = tmp_path / "compiled-model.ckpt"
+        trainer.save_checkpoint(checkpoint_path)
+
+    trainer.strategy.barrier()
+
+
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 @pytest.mark.parametrize(
     "compile",

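If needed, the single-file checkpoint written by the regression test above can be inspected directly. This is a hedged sketch: it assumes the standard Lightning checkpoint layout (top-level `state_dict` and `optimizer_states` entries) and the `compiled-model.ckpt` filename used in the test.

```python
import torch

# Inspect the consolidated checkpoint produced with save_distributed_checkpoint=False.
# weights_only=False is needed because a Lightning checkpoint stores more than tensors.
ckpt = torch.load("compiled-model.ckpt", map_location="cpu", weights_only=False)

print(sorted(ckpt))                    # expected to include "state_dict" and "optimizer_states"
print(len(ckpt["optimizer_states"]))   # one entry per configured optimizer

# Before this fix, trainer.save_checkpoint(...) raised KeyError while collecting the
# optimizer state for the compiled submodule; with the fix the file saves cleanly.
```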