5 changes: 2 additions & 3 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -646,15 +646,14 @@ def _update_metrics(
current: Optional[int] = None,
total_batches: bool = False,
) -> None:
if not self.is_enabled or self._metric_component is None:
return

if current is not None and not total_batches:
total = self.total_train_batches
if not self._should_update(current, total):
return

metrics = self.get_metrics(trainer, pl_module)
if not self.is_enabled or self._metric_component is None:
return
if self._metric_component:
self._metric_component.update(metrics)
Comment on lines 657 to 658

Copilot AI Dec 1, 2025


The condition if self._metric_component: is redundant. The previous check on line 655 already ensures that self._metric_component is not None when this line is reached. This check can be simplified or removed.

Suggested change
if self._metric_component:
self._metric_component.update(metrics)
self._metric_component.update(metrics)
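If the suggestion were applied on top of this change, the method would read roughly as below. This is a sketch reconstructed from the hunk above (the self/trainer/pl_module parameters and their annotations are assumptions, since the hunk only shows the tail of the signature); the point of the fix is that get_metrics runs on every rank before any rank-dependent early return.

# Sketch only: RichProgressBar._update_metrics with the fix plus the suggested
# simplification, reconstructed from the hunk above. Anything not visible in
# the hunk (e.g. the leading parameters) is an assumption.
from typing import Optional

def _update_metrics(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    current: Optional[int] = None,
    total_batches: bool = False,
) -> None:
    if current is not None and not total_batches:
        total = self.total_train_batches
        if not self._should_update(current, total):
            return

    # Compute metrics first: with sync_dist metrics this may enter a collective
    # op, so every rank must reach it, not only the rank that renders the bar.
    metrics = self.get_metrics(trainer, pl_module)
    if not self.is_enabled or self._metric_component is None:
        return
    self._metric_component.update(metrics)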


3 changes: 2 additions & 1 deletion src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -282,8 +282,9 @@ def on_train_batch_end(

@override
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
metrics = self.get_metrics(trainer, pl_module)
if not self.train_progress_bar.disable:
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.train_progress_bar.set_postfix(metrics)
if self._leave:
self.train_progress_bar.close()

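For context on both changes: a metric logged with sync_dist=True is reduced with a collective op that every rank must enter, so guarding get_metrics behind a rank-dependent condition (the progress bar is only enabled on rank 0) leaves the other ranks out of the collective and rank 0 hangs. The standalone sketch below is hypothetical and not part of this PR; it only illustrates that pattern with raw torch.distributed on the gloo backend.

# Hypothetical illustration of the deadlock pattern this PR fixes; not part of
# the diff. Launch with two processes, e.g.: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group("gloo")  # CPU backend, as in the tests below
    rank = dist.get_rank()
    value = torch.tensor([float(rank)])

    # Buggy shape (do not run): the collective is guarded by a rank-dependent
    # check, so rank 1 never enters it and rank 0 blocks forever.
    #
    #     if rank == 0:               # e.g. "progress bar enabled on this rank"
    #         dist.all_reduce(value)  # hangs, waiting for the other rank
    #
    # Fixed shape: every rank performs the reduction, and only the rank-local
    # work (updating the progress bar display) stays behind the rank check.
    dist.all_reduce(value)
    if rank == 0:
        print("reduced value:", value.item())  # 0.0 + 1.0 = 1.0

    dist.destroy_process_group()


if __name__ == "__main__":
    main()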
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import pickle
from collections import defaultdict
from unittest import mock
from unittest.mock import DEFAULT, Mock
from unittest.mock import DEFAULT, Mock, patch

import pytest
from tests_pytorch.helpers.runif import RunIf
@@ -26,6 +27,7 @@
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers.logger import DummyLogger
from lightning.pytorch.strategies import DDPStrategy


@RunIf(rich=True)
@@ -605,3 +607,55 @@ def val_dataloader(self):

# This should not raise an AssertionError
trainer.fit(model)



Copilot AI Dec 1, 2025


This test should have the @RunIf(rich=True) decorator to ensure it only runs when the rich library is available, since it instantiates RichProgressBar which raises ModuleNotFoundError when rich is not installed.

Suggested change
@RunIf(rich=True)

def test_rich_progress_bar_ddp_deadlock(tmp_path):
"""Tests that RichProgressBar doesn't deadlock when using DDP on train epoch end.

We used to have a bug where metrics were synced only on the rank 0 process. See
https://github.com/Lightning-AI/pytorch-lightning/issues/21264
for more details.

Comment on lines +614 to +618
Collaborator


Suggested change
We used to have a bug where metrics were synced only on the rank 0 process. See
https://github.com/Lightning-AI/pytorch-lightning/issues/21264
for more details.

"""
RichProgressBar()
Collaborator


Suggested change
RichProgressBar()

seems to not really be used?


# We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
class MyModel(BoringModel):
def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)["loss"]
self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
return {"loss": loss}

model = MyModel()

# We need to mock these logger connector hooks, since these also attempt to sync metrics
# and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.

Copilot AI Dec 1, 2025


The comment mentions 'TQDMProgressBar' but this test is for RichProgressBar. The comment should reference RichProgressBar instead.

Suggested change
# and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.
# and can "save" otherwise incorrect implementations of RichProgressBar.on_train_epoch_end.

def mock_on_epoch_end(self):
pass

def mock_update_train_epoch_metrics(self):
pass

with (
patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
patch(
"lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
mock_update_train_epoch_metrics,
),
):
trainer = Trainer(
default_root_dir=tmp_path,
num_sanity_val_steps=0,
max_epochs=1,
val_check_interval=1,
accelerator="cpu",
devices=2,
strategy=DDPStrategy(
process_group_backend="gloo", # run on CPU
timeout=datetime.timedelta(seconds=5), # timeout quickly for the test to fail
),
enable_progress_bar=True,

Copilot AI Dec 1, 2025


The RichProgressBar instance created on line 620 is not passed to the trainer's callbacks. This test should include callbacks=[pbar] in the Trainer initialization to actually test the RichProgressBar implementation. Without this, the test only uses the default progress bar behavior and doesn't exercise the fix in RichProgressBar._update_metrics().

enable_model_summary=False,
enable_checkpointing=False,
)
trainer.fit(model)
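Folding the review comments above into this test (guard on rich being installed, keep the RichProgressBar instance, and actually pass it to the trainer) would give roughly the skeleton below. This is a sketch of the suggested shape, mirroring the TQDM test that follows, not the code as committed.

# Sketch of the test skeleton with the suggestions above applied; the model,
# the logger-connector mocks, and the remaining Trainer arguments would stay
# exactly as in the body above.
@RunIf(rich=True)
def test_rich_progress_bar_ddp_deadlock(tmp_path):
    pbar = RichProgressBar()
    model = MyModel()  # the BoringModel subclass defined in the test above
    trainer = Trainer(
        default_root_dir=tmp_path,
        # ... same DDP/gloo arguments and patched hooks as above ...
        callbacks=[pbar],  # exercise RichProgressBar._update_metrics directly
        enable_progress_bar=True,
    )
    trainer.fit(model)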
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import math
import os
import pickle
@@ -32,6 +33,7 @@
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers.logger import DummyLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -859,3 +861,57 @@ def reset(self, total=None):
assert 2 in val_bar.total_values, (
f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
)


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_ddp_deadlock(tmp_path):
"""Tests that TQDMProgressBar doesn't deadlock when using DDP on train epoch end.

We used to have a bug where metrics were synced only on the rank 0 process. See
https://github.com/Lightning-AI/pytorch-lightning/issues/21264
for more details.

Comment on lines +869 to +873
Collaborator


Suggested change
We used to have a bug where metrics were synced only on the rank 0 process. See
https://github.com/Lightning-AI/pytorch-lightning/issues/21264
for more details.

"""
pbar = TQDMProgressBar()

# We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
class MyModel(BoringModel):
def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)["loss"]
self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
return {"loss": loss}

model = MyModel()

# We need to mock these logger connector hooks, since these also attempt to sync metrics
# and can "save" otherwise incorrect implementations of TQDMProgressBar.on_train_epoch_end.
def mock_on_epoch_end(self):
pass

def mock_update_train_epoch_metrics(self):
pass

with (
patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
patch(
"lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
mock_update_train_epoch_metrics,
),
):
trainer = Trainer(
default_root_dir=tmp_path,
num_sanity_val_steps=0,
max_epochs=1,
val_check_interval=1,
accelerator="cpu",
devices=2,
strategy=DDPStrategy(
process_group_backend="gloo", # run on CPU
timeout=datetime.timedelta(seconds=5), # timeout quickly for the test to fail
),
callbacks=[pbar],
enable_progress_bar=True,
enable_model_summary=False,
enable_checkpointing=False,
)
trainer.fit(model)