Fix progress bar deadlock on DDP metrics computation. #21322
base: master
Changes to the Rich progress bar tests:

```diff
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import pickle
 from collections import defaultdict
 from unittest import mock
-from unittest.mock import DEFAULT, Mock
+from unittest.mock import DEFAULT, Mock, patch

 import pytest
 from tests_pytorch.helpers.runif import RunIf
@@ -26,6 +27,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
@@ -605,3 +607,55 @@ def val_dataloader(self):
     # This should not raise an AssertionError
     trainer.fit(model)
+
+
+@RunIf(rich=True)
```

(The rest of the new Rich test is collapsed in the diff view; the review comments below quote parts of it.)
The new test's docstring, as quoted in the review:

> We used to have a bug where metrics were synced only on the rank 0 process. See
> https://github.com/Lightning-AI/pytorch-lightning/issues/21264
> for more details.

**Reviewer**, on the line `RichProgressBar()`:

> Seems to not really be used?
**Copilot AI** (Dec 1, 2025):

> The comment mentions 'TQDMProgressBar' but this test is for RichProgressBar. The comment should reference RichProgressBar instead.

Suggested change:

```diff
-    # and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.
+    # and can "save" otherwise incorrect implementations of RichProgressBar.on_train_epoch_end.
```
**Copilot AI** (Dec 1, 2025):

> The `RichProgressBar` instance created on line 620 is not passed to the trainer's callbacks. This test should include `callbacks=[pbar]` in the `Trainer` initialization to actually test the `RichProgressBar` implementation. Without this, the test only uses the default progress bar behavior and doesn't exercise the fix in `RichProgressBar._update_metrics()`.
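A minimal sketch of the adjustment this comment asks for; everything except `callbacks=[pbar]` is assumed from context, since the Rich test's other `Trainer` arguments are collapsed in this view:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

pbar = RichProgressBar()
# Passing the bar explicitly makes Lightning use it instead of the default
# TQDM bar, so RichProgressBar._update_metrics is actually exercised.
trainer = Trainer(callbacks=[pbar])  # plus the test's other arguments
```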
Changes to the TQDM progress bar tests:

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import math
 import os
 import pickle
@@ -32,6 +33,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -859,3 +861,57 @@ def reset(self, total=None):
     assert 2 in val_bar.total_values, (
         f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
     )
+
+
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+def test_tqdm_progress_bar_ddp_deadlock(tmp_path):
+    """Tests that TQDMProgressBar doesn't deadlock when using DDP on train epoch end.
+
+    We used to have a bug where metrics were synced only on the rank 0 process. See
+    https://github.com/Lightning-AI/pytorch-lightning/issues/21264
+    for more details.
+
+    """
+    pbar = TQDMProgressBar()
+
+    # We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+            return {"loss": loss}
+
+    model = MyModel()
+
+    # We need to mock these logger connector hooks, since these also attempt to sync metrics
+    # and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.
+    def mock_on_epoch_end(self):
+        pass
+
+    def mock_update_train_epoch_metrics(self):
+        pass
+
+    with (
+        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
+        patch(
+            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
+            mock_update_train_epoch_metrics,
+        ),
+    ):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            num_sanity_val_steps=0,
+            max_epochs=1,
+            val_check_interval=1,
+            accelerator="cpu",
+            devices=2,
+            strategy=DDPStrategy(
+                process_group_backend="gloo",  # run on CPU
+                timeout=datetime.timedelta(seconds=5),  # timeout quickly for the test to fail
+            ),
+            callbacks=[pbar],
+            enable_progress_bar=True,
+            enable_model_summary=False,
+            enable_checkpointing=False,
+        )
+        trainer.fit(model)
```

A collaborator also left a suggested change on lines +869 to +873 (the new test's docstring); the suggestion body is not visible in this view.
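For context on the deadlock under test: `sync_dist=True` makes the epoch-end metric reduction a collective operation, and a collective blocks until every rank joins it. Below is a standalone sketch, not part of the PR and with all names local to this example, of the rank-0-only pattern the test's docstring describes:

```python
import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int = 2) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(
        "gloo",  # CPU backend, as in the test
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=5),  # fail fast instead of hanging
    )
    loss = torch.tensor([float(rank)])
    if rank == 0:
        # Buggy pattern: only rank 0 tries to sync the metric. Rank 1 never
        # calls all_reduce, so rank 0 blocks until the 5 s timeout fires.
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)


if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)
```

Running this raises a timeout error on rank 0 instead of completing the reduction, which is the hang the test guards against; the short `DDPStrategy` timeout in the test serves the same fail-fast purpose in CI.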
**Reviewer**, on the implementation change (not shown in this view):

> The condition `if self._metric_component:` is redundant. The previous check on line 655 already ensures that `self._metric_component is not None` when this line is reached. This check can be simplified or removed.
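A hedged sketch of the shape this comment describes, assuming the method and attribute names mentioned in the comments above (`RichProgressBar._update_metrics`, `self._metric_component`) plus the standard `ProgressBar.get_metrics` helper; the actual implementation diff is not visible here:

```python
# Sketch only: the real method lives on RichProgressBar; names per the review.
def _update_metrics(self, trainer, pl_module) -> None:
    # Compute metrics on every rank so any sync_dist reduction is a true
    # collective call, not a rank-0-only one.
    metrics = self.get_metrics(trainer, pl_module)
    if self._metric_component is None:  # the earlier guard (around line 655)
        return
    if self._metric_component:  # flagged as redundant: it is non-None here
        self._metric_component.update(metrics)
```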