
Commit f7e53c1

JKSenthil authored and facebook-github-bot committed
enable prefetch in autounit (#875)
Summary:
Pull Request resolved: #875

# Context
Users cannot disable prefetch in auto unit.

# This diff
Adds an `enable_prefetch` flag to auto unit, which can be used to disable prefetching if needed.

Reviewed By: galrotem

Differential Revision: D59980065

fbshipit-source-id: 2a2f2f802b8d084d495a961839773f89c8d82022
1 parent 544a225 commit f7e53c1
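
For context, a minimal usage sketch of the new flag (the `MyUnit` subclass, its placeholder `compute_loss` / `configure_optimizers_and_lr_scheduler` bodies, and the toy `torch.nn.Linear` module are illustrative assumptions; only the `enable_prefetch` keyword comes from this diff):

from typing import Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]

class MyUnit(AutoUnit[Batch]):
    # Placeholder loss: mean squared error between module output and targets.
    def compute_loss(self, state: State, data: Batch):
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    # Placeholder optimizer; no LR scheduler.
    def configure_optimizers_and_lr_scheduler(self, module: torch.nn.Module):
        return torch.optim.SGD(module.parameters(), lr=0.01), None

# Prefetching stays on by default for AutoUnit; pass enable_prefetch=False to opt out.
unit = MyUnit(module=torch.nn.Linear(2, 2), enable_prefetch=False)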

File tree

2 files changed: +28 −4 lines changed


tests/framework/test_auto_unit.py

Lines changed: 11 additions & 0 deletions

@@ -642,6 +642,17 @@ def _assert_next_batch_dicts(
             },
         )
 
+    def test_enable_prefetch(self) -> None:
+        data = [1, 2, 3]
+        auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2), enable_prefetch=True)
+
+        _ = auto_unit._get_next_batch(get_dummy_train_state(), iter(data))
+        self.assertEqual(auto_unit._phase_to_next_batch[ActivePhase.TRAIN], 2)
+
+        auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2), enable_prefetch=False)
+        _ = auto_unit._get_next_batch(get_dummy_train_state(), iter(data))
+        self.assertIsNone(auto_unit._phase_to_next_batch[ActivePhase.TRAIN])
+
 
 Batch = Tuple[torch.Tensor, torch.Tensor]

torchtnt/framework/auto_unit.py

Lines changed: 17 additions & 4 deletions
@@ -168,6 +168,7 @@ def __init__(
         precision: Optional[Union[str, torch.dtype]] = None,
         detect_anomaly: Optional[bool] = None,
         torch_compile_params: Optional[TorchCompileParams] = None,
+        enable_prefetch: bool = True,
     ) -> None:
         super().__init__()
 
@@ -189,7 +190,9 @@ def __init__(
 
         # cuda stream to use for moving data to device
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
-            torch.cuda.Stream() if self.device.type == "cuda" else None
+            torch.cuda.Stream()
+            if (self.device.type == "cuda" and enable_prefetch)
+            else None
         )
         # dict mapping phase to whether the next batch which has been prefetched for that phase and is ready to be used
         self._phase_to_next_batch: dict[ActivePhase, Optional[TData]] = {
@@ -206,6 +209,7 @@ def __init__(
         }
         # whether the current batch is the last train batch
         self._is_last_batch: bool = False
+        self._enable_prefetch = enable_prefetch
 
     def move_data_to_device(
         self, state: State, data: TData, non_blocking: bool
@@ -253,6 +257,10 @@ def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None
         )
 
     def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
+        if not self._enable_prefetch:
+            batch = next(data)
+            return self.move_data_to_device(state, batch, non_blocking=False)
+
         active_phase = state.active_phase
         if not self._phase_to_prefetched[active_phase]:
             self._prefetch_next_batch(state, data)
@@ -293,6 +301,7 @@ def __init__(
         precision: Optional[Union[str, torch.dtype]] = None,
         torch_compile_params: Optional[TorchCompileParams] = None,
         detect_anomaly: Optional[bool] = None,
+        enable_prefetch: bool = False,
     ) -> None:
         """
         AutoPredictUnit is a convenience for users who are running inference and would like to have certain features handled for them, such as:
@@ -325,6 +334,7 @@ def __init__(
             precision=precision,
             torch_compile_params=torch_compile_params,
             detect_anomaly=detect_anomaly,
+            enable_prefetch=enable_prefetch,
         )
         self.module: torch.nn.Module = prepare_module(
             module,
@@ -435,9 +445,10 @@ class AutoUnit(
         training: if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
         enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
         loss_backward_retain_graph: If ``None`` or ``False``, the graph used to compute
-        the grads will be freed during loss backward pass. Note that in nearly all cases setting
-        this option to True is not needed and often can be worked around
-        in a much more efficient way.
+            the grads will be freed during loss backward pass. Note that in nearly all cases setting
+            this option to True is not needed and often can be worked around
+            in a much more efficient way.
+        enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
 
     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -468,13 +479,15 @@ def __init__(
         training: bool = True,
         enable_compiled_autograd: bool = False,
         loss_backward_retain_graph: Optional[bool] = None,
+        enable_prefetch: bool = True,
     ) -> None:
         super().__init__(
             module=module,
             device=device,
             precision=precision,
             detect_anomaly=detect_anomaly,
             torch_compile_params=torch_compile_params,
+            enable_prefetch=enable_prefetch,
         )
 
         if not gradient_accumulation_steps > 0:
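
Note the differing defaults introduced above: the training-oriented `AutoUnit` keeps `enable_prefetch=True`, while `AutoPredictUnit` defaults to `enable_prefetch=False`. A short sketch of opting into prefetching for prediction (the toy `torch.nn.Linear` module is only an illustration):

import torch
from torchtnt.framework.auto_unit import AutoPredictUnit

# AutoPredictUnit now defaults to enable_prefetch=False, so prefetching must be
# requested explicitly; per the diff above, on CUDA devices this also creates a
# dedicated prefetch stream for moving data to the device.
predict_unit = AutoPredictUnit(module=torch.nn.Linear(2, 2), enable_prefetch=True)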
