Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from collections.abc import Generator
from dataclasses import dataclass
from datetime import timedelta
Expand All @@ -22,13 +23,15 @@
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.types import STEP_OUTPUT

if _RICH_AVAILABLE:
from rich import get_console, reconfigure
from rich.console import Console, RenderableType
from rich.live import _RefreshThread as _RichRefreshThread
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar as _RichProgressBar
from rich.style import Style
Expand Down Expand Up @@ -66,9 +69,49 @@ class CustomInfiniteTask(Task):
def time_remaining(self) -> Optional[float]:
return None

class _RefreshThread(_RichRefreshThread):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.refresh_cond = False
super().__init__(*args, **kwargs)

def run(self) -> None:
while not self.done.is_set():
if self.refresh_cond:
with self.live._lock:
self.live.refresh()
self.refresh_cond = False
time.sleep(1 / self.refresh_per_second)

class CustomProgress(Progress):
"""Overrides ``Progress`` to support adding tasks that have an infinite total size."""

def start(self) -> None:
"""Starts the progress display.

Notes
-----
This override is needed to support the custom refresh thread.

"""
if self.live.auto_refresh:
self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second)
self.live.auto_refresh = False
super().start()
if self.live._refresh_thread:
self.live.auto_refresh = True
self.live._refresh_thread.start()

def stop(self) -> None:
refresh_thread = self.live._refresh_thread
super().stop()
if refresh_thread:
refresh_thread.stop()
refresh_thread.join()

def soft_refresh(self) -> None:
if self.live.auto_refresh and isinstance(self.live._refresh_thread, _RefreshThread):
self.live._refresh_thread.refresh_cond = True

def add_task(
self,
description: str,
Expand Down Expand Up @@ -239,8 +282,8 @@ class RichProgressBar(ProgressBar):
trainer = Trainer(callbacks=RichProgressBar())

Args:
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display.
refresh_rate: Determines at which rate (per second) the progress bars get updated.
Set it to ``0`` to disable the display. Default: 100
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.
console_kwargs: Args for constructing a `Console`
Expand All @@ -258,7 +301,7 @@ class RichProgressBar(ProgressBar):

def __init__(
self,
refresh_rate: int = 1,
refresh_rate: int = 100,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
console_kwargs: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -356,17 +399,21 @@ def _init_progress(self, trainer: "pl.Trainer") -> None:
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
auto_refresh=False,
auto_refresh=True,
refresh_per_second=self.refresh_rate if self.is_enabled else 1,
disable=self.is_disabled,
console=self._console,
)
self.progress.start()
# progress has started
self._progress_stopped = False

def refresh(self) -> None:
def refresh(self, hard: bool = False) -> None:
if self.progress:
self.progress.refresh()
if hard or _IS_INTERACTIVE:
self.progress.refresh()
else:
self.progress.soft_refresh()

@override
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand Down Expand Up @@ -466,16 +513,16 @@ def _initialize_train_progress_bar_id(self) -> None:
train_description = self._get_train_description(self.trainer.current_epoch)
self.train_progress_bar_id = self._add_task(total_batches, train_description)

def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
def _update(
self,
progress_bar_id: Optional["TaskID"],
current: int,
visible: bool = True,
hard: bool = False,
) -> None:
if self.progress is not None and self.is_enabled and progress_bar_id is not None:
total = self.progress.tasks[progress_bar_id].total
assert total is not None
if not self._should_update(current, total):
return
self.progress.update(progress_bar_id, completed=current, visible=visible)

def _should_update(self, current: int, total: Union[int, float]) -> bool:
return current % self.refresh_rate == 0 or current == total
self.refresh(hard=hard)

@override
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand Down Expand Up @@ -549,12 +596,13 @@ def on_train_batch_end(
# can happen when resuming from a mid-epoch restart
self._initialize_train_progress_bar_id()
self._update(self.train_progress_bar_id, batch_idx + 1)
self._update_metrics(trainer, pl_module, batch_idx + 1)
self._update_metrics(trainer, pl_module)
self.refresh()

@override
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._update_metrics(trainer, pl_module, total_batches=True)
self._update_metrics(trainer, pl_module)
self.refresh()

@override
def on_validation_batch_end(
Expand All @@ -576,7 +624,6 @@ def on_validation_batch_end(
if self.val_progress_bar_id is None:
return
self._update(self.val_progress_bar_id, batch_idx + 1)
self.refresh()

@override
def on_test_batch_end(
Expand All @@ -591,7 +638,6 @@ def on_test_batch_end(
if self.is_disabled or self.test_progress_bar_id is None:
return
self._update(self.test_progress_bar_id, batch_idx + 1)
self.refresh()

@override
def on_predict_batch_end(
Expand All @@ -606,7 +652,6 @@ def on_predict_batch_end(
if self.is_disabled or self.predict_progress_bar_id is None:
return
self._update(self.predict_progress_bar_id, batch_idx + 1)
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"
Expand Down Expand Up @@ -643,17 +688,10 @@ def _update_metrics(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
current: Optional[int] = None,
total_batches: bool = False,
) -> None:
if not self.is_enabled or self._metric_component is None:
return

if current is not None and not total_batches:
total = self.total_train_batches
if not self._should_update(current, total):
return

metrics = self.get_metrics(trainer, pl_module)
if self._metric_component:
self._metric_component.update(metrics)
Expand Down
34 changes: 19 additions & 15 deletions tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def test_rich_progress_bar_custom_theme():
_, kwargs = mocks["ProcessingSpeedColumn"].call_args
assert kwargs["style"] == theme.processing_speed

progress_bar.progress.live._refresh_thread.stop()
progress_bar.progress.live._refresh_thread.join()


@RunIf(rich=True)
def test_rich_progress_bar_keyboard_interrupt(tmp_path):
Expand Down Expand Up @@ -176,6 +179,8 @@ def configure_columns(self, trainer):
assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 2

progress_bar.progress.stop()


@RunIf(rich=True)
@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)]))
Expand Down Expand Up @@ -216,30 +221,27 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmp_path):

@RunIf(rich=True)
@pytest.mark.parametrize(
("refresh_rate", "train_batches", "val_batches", "expected_call_count"),
("train_batches", "val_batches", "expected_call_count"),
[
# note: there is always one extra update at the very end (+1)
(3, 6, 6, 2 + 2 + 1),
(4, 6, 6, 2 + 2 + 1),
(7, 6, 6, 1 + 1 + 1),
(1, 2, 3, 2 + 3 + 1),
(1, 0, 0, 0 + 0),
(3, 1, 0, 1 + 0),
(3, 1, 1, 1 + 1 + 1),
(3, 5, 0, 2 + 0),
(3, 5, 2, 2 + 1 + 1),
(6, 5, 2, 1 + 1 + 1),
(6, 6, 6 + 6 + 1),
(2, 3, 2 + 3 + 1),
(0, 0, 0 + 0),
(1, 0, 1 + 0),
(1, 1, 1 + 1 + 1),
(5, 0, 5 + 0),
(5, 2, 5 + 2 + 1),
],
)
def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batches, val_batches, expected_call_count):
def test_rich_progress_bar_update_counts(tmp_path, train_batches, val_batches, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmp_path,
num_sanity_val_steps=0,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
callbacks=RichProgressBar(),
)

trainer.progress_bar_callback.on_train_start(trainer, model)
Expand Down Expand Up @@ -345,7 +347,8 @@ def training_step(self, *args, **kwargs):

for key in ("loss", "v_num", "train_loss"):
assert key in rendered[train_progress_bar_id][1]
assert key not in rendered[val_progress_bar_id][1]
if val_progress_bar_id in rendered:
assert key not in rendered[val_progress_bar_id][1]


def test_rich_progress_bar_metrics_fast_dev_run(tmp_path):
Expand All @@ -359,7 +362,8 @@ def test_rich_progress_bar_metrics_fast_dev_run(tmp_path):
val_progress_bar_id = progress_bar.val_progress_bar_id
rendered = progress_bar.progress.columns[-1]._renderable_cache
assert "v_num" not in rendered[train_progress_bar_id][1]
assert "v_num" not in rendered[val_progress_bar_id][1]
if val_progress_bar_id in rendered:
assert "v_num" not in rendered[val_progress_bar_id][1]


@RunIf(rich=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,14 @@ def on_validation_epoch_end(self, *args):

def test_tqdm_progress_bar_default_value(tmp_path):
"""Test that a value of None defaults to refresh rate 1."""
trainer = Trainer(default_root_dir=tmp_path)
trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
assert trainer.progress_bar_callback.refresh_rate == 1


@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_value_on_colab(tmp_path):
"""Test that Trainer will override the default in Google COLAB."""
trainer = Trainer(default_root_dir=tmp_path)
assert trainer.progress_bar_callback.refresh_rate == 20

trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
assert trainer.progress_bar_callback.refresh_rate == 20

Expand Down
Loading