You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``/``predict_step``.
@@ -230,8 +239,18 @@ def move_data_to_device(
230
239
231
240
Returns:
232
241
A batch of data which is on the device
242
+
243
+
Note:
244
+
If overriding, ensure that tensors are recorded on the compute stream to avoid the cuda cache allocator from
245
+
overwriting the underlying data before the compute stream has a chance to use it. If using `copy_data_to_device`,
246
+
you can pass `stream_to_record=self._default_stream` as an argument.
0 commit comments