@@ -95,7 +95,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
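
The `is not None` check matters for the same reason as the comment above: an empty-string id is a valid value and must not be treated as "unset". A tiny self-contained illustration of the pitfall:

```python
# Why `group_id or str(id(obj))` would be wrong: "" is falsy, so an empty
# string id would be silently replaced instead of preserved.
group_id = ""
print(group_id or "fallback")                            # fallback (wrong)
print(group_id if group_id is not None else "fallback")  # "" (preserved)
```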
@@ -115,6 +115,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
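
The hoisted `_torch_accelerator_module` caches the backend module once per group instead of recomputing it on every onload/offload call. A minimal sketch of what the expression resolves to (assumes PyTorch 2.6+ for `torch.accelerator` and that an accelerator is actually present):

```python
import torch

# On PyTorch >= 2.6, torch.accelerator.current_accelerator() returns a
# torch.device such as device(type="cuda"); getattr(torch, "cuda") then
# yields the backend module. Older builds fall back to torch.cuda.
accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)
print(accelerator_module)  # e.g. <module 'torch.cuda' ...> on a CUDA machine
```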
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):

     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)

         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)

-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()

-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
-
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]

+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
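
For reference, the pinned-memory transfer pattern that the reworked onload paths rely on looks roughly like the standalone sketch below (illustrative tensor size; assumes a CUDA device). The copy runs on a side stream so it can overlap compute, and `record_stream` tells the caching allocator the tensor will also be consumed on the default stream:

```python
import torch

default_stream = torch.cuda.current_stream()  # captured before entering the side stream
side_stream = torch.cuda.Stream()

# Pinned (page-locked) host memory is required for a truly asynchronous H2D copy.
cpu_tensor = torch.randn(1024, 1024).pin_memory()

with torch.cuda.stream(side_stream):
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)
    # Prevent the allocator from recycling gpu_tensor's memory until work
    # queued on the default stream (which will consume it) has finished.
    gpu_tensor.record_stream(default_stream)

# Mirrors self.stream.synchronize(): wait for the transfer before the next one starts.
side_stream.synchronize()
```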
@@ -264,14 +234,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ def _offload_to_memory(self):

         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
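
The switch to `non_blocking=False` in the stream-less path avoids a documented pitfall: a device-to-CPU copy into unpinned memory with `non_blocking=True` can return before the data has actually landed. A minimal sketch of the hazard (assumes a CUDA device):

```python
import torch

gpu = torch.ones(1024, device="cuda") * 3
cpu = gpu.to("cpu", non_blocking=True)  # copy may still be in flight here
torch.cuda.synchronize()                # required before cpu is safe to read
print(cpu.sum())                        # correct only after the sync
```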
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
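
Since `next_group` is no longer a constructor argument, prefetch chains are wired up after the hooks exist. A self-contained toy sketch of that post-init linking pattern (the `Hook` class below is a stand-in, not code from this diff):

```python
from typing import Optional

class Hook:
    """Toy stand-in for GroupOffloadingHook: next_group starts as None."""
    def __init__(self, group: str) -> None:
        self.group = group
        self.next_group: Optional[str] = None

# Construct hooks first, then wire each one to its successor's group.
hooks = [Hook(g) for g in ("block_0", "block_1", "block_2")]
for hook, successor in zip(hooks, hooks[1:]):
    hook.next_group = successor.group

print([h.next_group for h in hooks])  # ['block_1', 'block_2', None]
```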
@@ -459,8 +431,8 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
     ```
     """

+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None
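
After this normalization, callers can pass plain strings for both devices. A hedged usage sketch (the import path and toy model are assumptions, not part of this diff; assumes a CUDA device):

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path assumed

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level offloading groups ModuleList/Sequential blocks together.
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()
# "cuda" and "cpu" are now accepted directly and converted to torch.device.
apply_group_offloading(
    model, onload_device="cuda", offload_device="cpu",
    offload_type="block_level", num_blocks_per_group=2,
)
```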
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)

     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()