
Commit 89b12e6

Xiaolong Wang authored and facebook-github-bot committed
enable be (#2851)
Summary:
Pull Request resolved: #2851

Set up using lsr as input; this needs to change once the input is ready.

On exhaustiveness:
- the dataloader requires `drop_incomplete=False`
- disable fullsync and enable exhaustiveness
- fix torchrec train_pipeline, which didn't account for exhaustiveness

Need to follow up on FI behavior when in BE:
- if the cache is never updated (0 in standalone BE, or < itrn when training finishes), update it
- if the cache is up to date, use it

Differential Revision: D71827944

Privacy Context Container: L1292699

fbshipit-source-id: 1460a9673845decdeecc53d0fc292dd360125146
1 parent 577d4bd commit 89b12e6
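
The exhaustiveness notes above hinge on per-rank batch counts diverging once trailing partial batches are kept. A minimal sketch, assuming hypothetical shard sizes (not from this commit), of why some ranks exhaust their data before others:

import math

batch_size = 4
rows_per_rank = {0: 9, 1: 13}  # hypothetical uneven shard sizes

for rank, rows in rows_per_rank.items():
    # dropping the incomplete tail floors; exhaustive iteration ceils
    dropped = rows // batch_size
    exhaustive = math.ceil(rows / batch_size)
    print(f"rank {rank}: {dropped} batches (tail dropped) vs {exhaustive} (exhaustive)")

# rank 0: 2 batches (tail dropped) vs 3 (exhaustive)
# rank 1: 3 batches (tail dropped) vs 4 (exhaustive)
# rank 0 hits StopIteration a step before rank 1, so the pipeline on rank 0
# must keep progressing with cur_batch=None while rank 1 finishes.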

File tree

1 file changed: +24 -8 lines changed


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 24 additions & 8 deletions
@@ -139,17 +139,24 @@ def __init__(
         )
         self._cur_batch: Optional[In] = None
         self._connected = False
+        self._data_iter_stopped = False

     def _connect(self, dataloader_iter: Iterator[In]) -> None:
         cur_batch = next(dataloader_iter)
         self._cur_batch = cur_batch
-        with self._stream_context(self._memcpy_stream):
-            self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
+        if cur_batch is not None:
+            with self._stream_context(self._memcpy_stream):
+                self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
         self._connected = True

-    def _next_batch(self, dataloader_iter: Iterator[In]) -> In:
+    def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         with record_function("## next_batch ##"):
-            next_batch = next(dataloader_iter)
+            try:
+                next_batch = next(dataloader_iter)
+            except StopIteration:
+                self._data_iter_stopped = True
+                return None
         return next_batch

     def _wait_for_batch(self, cur_batch: In) -> None:
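
The `_next_batch` change above swallows StopIteration, latches `_data_iter_stopped`, and hands back None as a sentinel. A standalone sketch of that pattern, using a hypothetical class rather than the pipeline itself:

from typing import Iterator, Optional, TypeVar

T = TypeVar("T")

class ExhaustibleSource:
    """Hypothetical stand-in mirroring the _next_batch/_data_iter_stopped pattern."""

    def __init__(self) -> None:
        self.stopped = False  # plays the role of self._data_iter_stopped

    def next_batch(self, it: Iterator[T]) -> Optional[T]:
        try:
            return next(it)
        except StopIteration:
            # Latch the flag instead of propagating, so the caller can finish
            # the in-flight step and raise on the next call.
            self.stopped = True
            return None

src = ExhaustibleSource()
it = iter([1, 2])
print(src.next_batch(it), src.next_batch(it), src.next_batch(it), src.stopped)
# -> 1 2 None True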
@@ -168,18 +175,26 @@ def _copy_batch_to_gpu(self, cur_batch: In) -> None:
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if not self._connected:
             self._connect(dataloader_iter)
+        if self._data_iter_stopped:
+            raise StopIteration()

-        # Fetch next batch
+        # Fetch the next batch; if the iterator is depleted, raise at the start of the next progress()
         next_batch = self._next_batch(dataloader_iter)
         cur_batch = self._cur_batch
-        assert cur_batch is not None
+
+        # With an exhaustive data iterator, some ranks deplete their data first,
+        # but we still need to progress the train pipeline for the other ranks;
+        # cur_batch may be None here

         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()

-        self._wait_for_batch(cur_batch)
+        if cur_batch is not None:
+            self._wait_for_batch(cur_batch)

+        # The model must handle an empty cur_batch; this is needed when there are
+        # collective communication ops
         with record_function("## forward ##"):
             (losses, output) = self._model(cur_batch)
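
Per the comment in this hunk, the forward still runs with `cur_batch=None` on exhausted ranks so that any collective communication inside the model fires on every rank. A hedged sketch, with a hypothetical module (not from this commit), of what tolerating an empty batch could look like:

import torch
from torch import nn

class ToleratesEmptyBatch(nn.Module):
    """Hypothetical model whose forward accepts batch=None on exhausted ranks."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, 1)

    def forward(self, batch):
        if batch is None:
            # Zero-row stand-in: sharded modules would still enter their
            # collectives, just with no local elements.
            batch = torch.zeros(0, self.proj.in_features)
        out = self.proj(batch)
        return out.sum(), out  # (losses, output), matching the pipeline's unpacking

model = ToleratesEmptyBatch()
losses, output = model(None)
print(losses.item(), output.shape)  # 0.0 torch.Size([0, 1])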

@@ -188,7 +203,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

         # Copy the next batch to GPU
         self._cur_batch = cur_batch = next_batch
-        self._copy_batch_to_gpu(cur_batch)
+        if cur_batch is not None:
+            self._copy_batch_to_gpu(cur_batch)

         # Update
         if self._model.training:
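
Taken together, `progress()` now raises StopIteration at the start of the call after the iterator is depleted, not mid-step. A usage sketch with a hypothetical driver loop:

def run_epoch(pipeline, dataloader) -> int:
    """Drive an exhaustive pipeline until progress() signals depletion."""
    it = iter(dataloader)
    steps = 0
    while True:
        try:
            pipeline.progress(it)
        except StopIteration:
            break
        steps += 1
    return steps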
