
Commit de119c5

JKSenthil authored and facebook-github-bot committed
do record_data_in_stream step in copy_data_to_device (#956)
Summary: Pull Request resolved: #956

Reviewed By: galrotem

Differential Revision: D67719965

fbshipit-source-id: 71dde3aaf42f70bd6fa79ce5634f4ccea3d4e6e2
1 parent c8a8e76 commit de119c5

2 files changed: +39 −14 lines

torchtnt/framework/auto_unit.py

Lines changed: 25 additions & 11 deletions
@@ -33,7 +33,7 @@
 from torchtnt.framework.state import ActivePhase, EntryPoint, State
 from torchtnt.framework.unit import EvalUnit, PredictUnit, TPredictData, TrainUnit
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
+from torchtnt.utils.device import copy_data_to_device
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.precision import (
@@ -191,6 +191,12 @@ def __init__(
             enabled=self.precision is not None,
         )
 
+        # main stream responsible for computation on the device
+        self._default_stream: Optional[torch.cuda.streams.Stream] = (
+            torch.cuda.current_stream()
+            if (self.device.type == "cuda" and enable_prefetch)
+            else None
+        )
         # cuda stream to use for moving data to device
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
             torch.cuda.Stream()
@@ -215,7 +221,10 @@ def __init__(
         self._enable_prefetch = enable_prefetch
 
     def move_data_to_device(
-        self, state: State, data: TData, non_blocking: bool
+        self,
+        state: State,
+        data: TData,
+        non_blocking: bool,
     ) -> TData:
         """
         The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``/``predict_step``.
@@ -230,8 +239,18 @@ def move_data_to_device(
 
         Returns:
             A batch of data which is on the device
+
+        Note:
+            If overriding, ensure that tensors are recorded on the compute stream to avoid the cuda cache allocator from
+            overwriting the underlying data before the compute stream has a chance to use it. If using `copy_data_to_device`,
+            you can pass `stream_to_record=self._default_stream` as an argument.
         """
-        return copy_data_to_device(data, self.device, non_blocking=non_blocking)
+        return copy_data_to_device(
+            data,
+            self.device,
+            non_blocking=non_blocking,
+            stream_to_record=self._default_stream,
+        )
 
     def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None:
         """Prefetch the next batch on a separate CUDA stream."""
@@ -256,7 +275,9 @@ def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None
                 state, f"{self.__class__.__name__}.{phase}.move_data_to_device"
            ):
                self._phase_to_next_batch[active_phase] = self.move_data_to_device(
-                    state, next_batch, non_blocking=non_blocking
+                    state,
+                    next_batch,
+                    non_blocking=non_blocking,
                )
 
     def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
@@ -281,13 +302,6 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
            self._is_last_batch = False
            raise StopIteration
 
-        if self._prefetch_stream:
-            with get_timing_context(
-                state, f"{self.__class__.__name__}.record_data_in_stream"
-            ):
-                # record the batch in the current stream
-                record_data_in_stream(batch, torch.cuda.current_stream())
-
         # prefetch the next batch
         self._prefetch_next_batch(state, data)
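The new docstring Note matters mainly for subclasses that override move_data_to_device. A minimal sketch of a compliant override is below (illustrative only, not part of this commit: it assumes the batch is a dict of tensors, assumes a CUDA device with prefetching enabled so self._default_stream is set, and omits the other methods an AutoUnit subclass must implement; the MyUnit name is hypothetical):

from typing import Dict

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State


class MyUnit(AutoUnit[Dict[str, torch.Tensor]]):
    # other required AutoUnit methods omitted for brevity

    def move_data_to_device(
        self,
        state: State,
        data: Dict[str, torch.Tensor],
        non_blocking: bool,
    ) -> Dict[str, torch.Tensor]:
        # custom copy logic: move each tensor to the unit's device
        batch = {
            k: v.to(self.device, non_blocking=non_blocking) for k, v in data.items()
        }
        # record each tensor on the compute stream so the caching allocator does not
        # reuse its memory before the compute stream has consumed it (per the Note above)
        if self._default_stream is not None:
            for t in batch.values():
                t.record_stream(self._default_stream)
        return batch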

torchtnt/utils/device.py

Lines changed: 14 additions & 3 deletions
@@ -13,7 +13,7 @@
 import subprocess
 from collections import defaultdict
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, Mapping, TypeVar
+from typing import Any, Dict, Mapping, Optional, TypeVar
 
 import torch
 from typing_extensions import Protocol, runtime_checkable, TypedDict
@@ -56,12 +56,20 @@ def _is_named_tuple(x: T) -> bool:
     return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
 
 
-def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any) -> T:
+def copy_data_to_device(
+    data: T,
+    device: torch.device,
+    stream_to_record: Optional[torch.cuda.Stream] = None,
+    *args: Any,
+    **kwargs: Any,
+) -> T:
     """Function that recursively copies data to a torch.device.
 
     Args:
         data: The data to copy to device
         device: The device to which the data should be copied
+        stream_to_record: The CUDA stream to which the data should be recorded. Useful if this function is called
+            on side stream, and the data is expected to be used on the main stream.
         args: positional arguments that will be passed to the `to` call
         kwargs: keyword arguments that will be passed to the `to` call
@@ -116,7 +124,10 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
            return new_data_class
    elif hasattr(data, "to"):
        # pyre-ignore Undefined attribute [16]: `Variable[T]` has no attribute `to`
-        return data.to(device, *args, **kwargs)
+        gpu_data = data.to(device, *args, **kwargs)
+        if stream_to_record is not None and hasattr(gpu_data, "record_stream"):
+            gpu_data.record_stream(stream_to_record)
+        return gpu_data
 
    return data
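For context, the new stream_to_record parameter targets the pattern where the copy runs on a side stream while the data is later consumed on the main stream, which is exactly how AutoUnit's prefetching uses it. A hedged usage sketch follows (it assumes a CUDA-capable machine and the post-commit copy_data_to_device signature; the tensor names and sizes are illustrative):

import torch
from torchtnt.utils.device import copy_data_to_device

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()
main_stream = torch.cuda.current_stream()

# pinned host memory allows the host-to-device copy to be truly asynchronous
batch = {"x": torch.randn(32, 16, pin_memory=True)}

# copy on the side stream, but record the resulting tensors on the main stream so the
# caching allocator keeps their memory alive until the main stream has used them
with torch.cuda.stream(copy_stream):
    gpu_batch = copy_data_to_device(
        batch, device, non_blocking=True, stream_to_record=main_stream
    )

# the main stream must still wait for the copy itself before consuming the batch
main_stream.wait_stream(copy_stream)
out = gpu_batch["x"].sum()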
