@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
@@ -68,11 +69,14 @@ def __init__(
         self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-
         self.cpu_param_dict = self._init_cpu_param_dict()
 
+        if self.stream is None and self.record_stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
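For context on the guard added in `__init__`: `torch.Tensor.record_stream` only matters when a tensor is produced on one CUDA stream and consumed on another, so enabling it without a transfer stream would be a silent no-op at best. A minimal sketch of the semantics, using only public PyTorch APIs (illustrative, not part of this patch):

```python
import torch

# record_stream marks a tensor as "in use" on the given stream, so the caching
# allocator will not recycle its memory until that stream's queued work drains.
transfer_stream = torch.cuda.Stream()
host = torch.randn(1024, pin_memory=True)

with torch.cuda.stream(transfer_stream):
    device = host.to("cuda", non_blocking=True)  # async H2D copy on transfer_stream

# Mark `device` as used by the default stream; without this (or an explicit
# synchronize), releasing it early could let the allocator reuse the memory
# while the copy or a consumer kernel is still in flight.
device.record_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(transfer_stream)
```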
@@ -112,6 +116,8 @@ def _pinned_memory_tensors(self):
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -122,14 +128,22 @@ def onload_(self):
                     for group_module in self.modules:
                         for param in group_module.parameters():
                             param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                param.data.record_stream(current_stream)
                         for buffer in group_module.buffers():
                             buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                buffer.data.record_stream(current_stream)
 
                     for param in self.parameters:
                         param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            param.data.record_stream(current_stream)
 
                     for buffer in self.buffers:
                         buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            buffer.data.record_stream(current_stream)
 
             else:
                 for group_module in self.modules:
@@ -143,11 +157,14 @@ def onload_(self):
 
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
+            if not self.record_stream:
+                torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
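The `offload_` change above is where the flag pays off: tensors onloaded with `record_stream=True` are already tracked by the caching allocator across streams, so the host no longer needs a blocking `synchronize()` before `param.data` is repointed at the CPU copies. A rough sketch of the two strategies (a hypothetical helper, not the library's API):

```python
import torch

def release_after_use(on_gpu: torch.Tensor, use_record_stream: bool) -> None:
    # Hypothetical helper mirroring the offload_() trade-off above.
    if not use_record_stream:
        # Blocking: wait on the host until all queued work that may still read
        # `on_gpu` has finished, so its storage can be dropped safely.
        torch.cuda.current_stream().synchronize()
    else:
        # Non-blocking: tell the allocator the current stream uses this tensor;
        # its memory is held (slightly higher peak usage) until the stream drains.
        on_gpu.record_stream(torch.cuda.current_stream())
```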
@@ -331,6 +348,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -378,6 +396,10 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor as
+            having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
+            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
+            more details.
         low_cpu_mem_usage (`bool`, defaults to `False`):
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
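Putting it together at the public entry point, a hedged usage sketch (the toy model and import path are illustrative and may differ across diffusers versions):

```python
import torch
from torch import nn
from diffusers.hooks import apply_group_offloading  # import path may vary by version

# Toy stand-in; real use targets a diffusers transformer or UNet.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,     # transfers run on a dedicated CUDA stream
    record_stream=True,  # skip the blocking synchronize() in offload_()
)

with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda"))
```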
@@ -417,11 +439,24 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            num_blocks_per_group=num_blocks_per_group,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor as
+            having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
+            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
+            more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor as
+            having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
+            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
+            more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
         buffers=buffers,
         non_blocking=non_blocking,
         stream=stream,
+        record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )