Commit d1a9515

Lucy Liao authored and facebook-github-bot committed
Fix train-eval mode on trunk
Summary: Train-eval mode is broken due to D71827944: the train pipeline automatically raises StopIteration when switching between train mode and eval mode. This diff sets _data_iter_stopped = False when switching between train mode and eval mode.

Reviewed By: dragonxlwang

Differential Revision: D72582264

fbshipit-source-id: ab9454858a7aa0ec30f96cdab17ed466292951c4
1 parent 074216a commit d1a9515

1 file changed, 5 additions and 0 deletions

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 5 additions & 0 deletions
@@ -141,6 +141,11 @@ def __init__(
         self._connected = False
         self._data_iter_stopped = False
 
+    def _reset_data_iter(self) -> None:
+        self._connected = False
+        self._data_iter_stopped = False
+        self._cur_batch = None
+
     def _connect(self, dataloader_iter: Iterator[In]) -> None:
         cur_batch = next(dataloader_iter)
         self._cur_batch = cur_batch
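
The failure described in the commit message can be reproduced with a toy model. The sketch below is not the torchrec implementation: `ToyPipeline`, its `progress` method, and the `drain` helper are hypothetical stand-ins, and the way `progress` consults `_data_iter_stopped` is an assumption made for illustration. Only `_reset_data_iter` mirrors the method added in this commit. It shows how a stale "stopped" flag from the train iterator can make a fresh eval iterator appear exhausted, and how resetting the three fields avoids that.

```python
from typing import Iterable, Iterator, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for a train pipeline (not the torchrec class)."""

    def __init__(self) -> None:
        self._connected = False
        self._data_iter_stopped = False
        self._cur_batch: Optional[int] = None

    def _reset_data_iter(self) -> None:
        # Same three assignments as the method added in this commit.
        self._connected = False
        self._data_iter_stopped = False
        self._cur_batch = None

    def progress(self, dataloader_iter: Iterator[int]) -> int:
        # Assumption: once exhaustion has been recorded, every later call
        # raises StopIteration, even if a new iterator is passed in.
        if self._data_iter_stopped:
            raise StopIteration
        if not self._connected:
            # Prefetch the first batch of this iterator.
            self._cur_batch = next(dataloader_iter, None)
            self._connected = True
        if self._cur_batch is None:
            self._data_iter_stopped = True
            raise StopIteration
        batch = self._cur_batch
        # Prefetch the next batch; None marks the end of the data.
        self._cur_batch = next(dataloader_iter, None)
        return batch


def drain(pipeline: ToyPipeline, data: Iterable[int]) -> List[int]:
    """Hypothetical helper: run the pipeline until it raises StopIteration."""
    data_iter = iter(data)
    seen: List[int] = []
    while True:
        try:
            seen.append(pipeline.progress(data_iter))
        except StopIteration:
            return seen


pipeline = ToyPipeline()
train_batches = drain(pipeline, [1, 2, 3])   # train pass exhausts its iterator
stale_eval = drain(pipeline, [10, 20])       # flag is still set: nothing is consumed
pipeline._reset_data_iter()                  # reset when switching train -> eval
fresh_eval = drain(pipeline, [10, 20])       # eval iterator is now consumed normally
```

In this sketch the caller resets the pipeline explicitly between the train and eval passes; where the real pipeline triggers the reset is not shown in this diff.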
