Skip to content

Commit a577dd4

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
sync epoch number print in train.py (#905)
Summary: Pull Request resolved: #905 # Context In train.py, we log the number of steps underwent in a epoch here: https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=276-279&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a The printed epoch is intentionally +1 since internally torchtnt starts the epoch at 0. However, in a prior print which logs reason why epoch finished, the epoch is not +1. https://www.internalfb.com/code/fbsource/[9ef12f3ec86db52d063788e881d9f9d3f1209b7e]/fbcode/torchtnt/framework/train.py?lines=258-264&base=d2dcd7009ea1ca4cfb2a5ba2025caa2ec04a7e7a # This Diff Add +1 so both prints are synced Reviewed By: anshulverma, vbourgin Differential Revision: D63557368 fbshipit-source-id: 48e9a4a1ebe63abd07354402aefa6c750245bfda
1 parent d86828b commit a577dd4

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tests/framework/test_loop_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,17 @@ def test_log_reason_epoch_completed(self) -> None:
199199
p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=False
200200
)
201201
self.assertEqual(
202-
reason, "Train epoch 2 ended as max steps per epoch reached: 5"
202+
reason, "Train epoch 3 ended as max steps per epoch reached: 5"
203203
)
204204

205205
reason = _reason_epoch_completed(
206206
p, max_steps_per_epoch=6, max_steps=100, stop_iteration_reached=False
207207
)
208-
self.assertEqual(reason, "Train epoch 2 ended as max steps reached: 100")
208+
self.assertEqual(reason, "Train epoch 3 ended as max steps reached: 100")
209209

210210
reason = _reason_epoch_completed(
211211
p, max_steps_per_epoch=5, max_steps=None, stop_iteration_reached=True
212212
)
213213
self.assertEqual(
214-
reason, "Train epoch 2 ended as it reached end of train dataloader"
214+
reason, "Train epoch 3 ended as it reached end of train dataloader"
215215
)

torchtnt/framework/_loop_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _reason_epoch_completed(
4949
max_steps: Optional[int],
5050
stop_iteration_reached: bool,
5151
) -> str:
52-
current_epoch = progress.num_epochs_completed
52+
current_epoch = progress.num_epochs_completed + 1
5353
if stop_iteration_reached:
5454
return (
5555
f"Train epoch {current_epoch} ended as it reached end of train dataloader"

0 commit comments

Comments
 (0)