Skip to content

Commit 4c6a884

Browse files
alanhdufacebook-github-bot
authored andcommitted
Allow setting mininterval for TQDM progress bar (#987)
Summary: Pull Request resolved: #987 This uses `tqdm`'s exposes built-in `mininterval` to the callback -- this allows the user to control the TQDM update rate based on the number of seconds rather than the number of iteration steps. Reviewed By: JKSenthil Differential Revision: D72068311 fbshipit-source-id: 91105da3db1dbadac5b89f9865406f75035dccb6
1 parent d050dcd commit 4c6a884

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

torchtnt/framework/callbacks/tqdm_progress_bar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,19 @@ class TQDMProgressBar(Callback):
3030
3131
Args:
3232
refresh_rate: Determines at which rate (in number of steps) the progress bars get updated.
33+
mininterval: Minimum display update interval (in seconds). If None, use TQDM's default.
3334
file: specifies where to output the progress messages (default: sys.stderr)
3435
"""
3536

3637
def __init__(
3738
self,
3839
refresh_rate: int = 1,
3940
file: Optional[Union[TextIO, io.StringIO]] = None,
41+
*,
42+
mininterval: float | None = None,
4043
) -> None:
4144
self._refresh_rate = refresh_rate
45+
self._mininterval = mininterval
4246
self._file = file
4347

4448
self._train_progress_bar: Optional[tqdm] = None
@@ -56,6 +60,7 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
5660
max_steps=train_state.max_steps,
5761
max_steps_per_epoch=train_state.max_steps_per_epoch,
5862
file=self._file,
63+
mininterval=self._mininterval,
5964
)
6065

6166
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
@@ -87,6 +92,7 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
8792
max_steps=eval_state.max_steps,
8893
max_steps_per_epoch=eval_state.max_steps_per_epoch,
8994
file=self._file,
95+
mininterval=self._mininterval,
9096
)
9197

9298
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:

torchtnt/utils/tqdm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def create_progress_bar(
2525
num_steps_completed: int,
2626
max_steps: Optional[int],
2727
max_steps_per_epoch: Optional[int],
28+
mininterval: float | None = None,
2829
file: Optional[Union[TextIO, io.StringIO]] = None,
2930
) -> tqdm:
3031
"""Constructs a :func:`tqdm` progress bar. The number of steps in an epoch is inferred from the dataloader, num_steps_completed, max_steps and max_steps_per_epoch.
@@ -36,6 +37,7 @@ def create_progress_bar(
3637
num_steps_completed: an integer for the number of steps completed so far in the loop.
3738
max_steps: an optional integer for the number of max steps in the loop.
3839
max_steps_per_epoch: an optional integer for the number of max steps per epoch.
40+
mininterval: Minimum display update interval (in seconds). If None, use TQDM's default.
3941
file: specifies where to output the progress messages (default: sys.stderr)
4042
"""
4143
current_epoch = num_epochs_completed
@@ -45,12 +47,16 @@ def create_progress_bar(
4547
max_steps=max_steps,
4648
max_steps_per_epoch=max_steps_per_epoch,
4749
)
50+
kwargs = {}
51+
if mininterval is not None:
52+
kwargs["mininterval"] = mininterval
4853
return tqdm(
4954
desc=f"{desc} {current_epoch}",
5055
total=total,
5156
initial=num_steps_completed,
5257
bar_format="{l_bar}{bar}{r_bar}\n",
5358
file=file,
59+
**kwargs,
5460
)
5561

5662

0 commit comments

Comments
 (0)