 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
 
+import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()
 
         if self.stream is None and self.record_stream:
             raise ValueError("`record_stream` cannot be True when `stream` is None.")
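The `dict.fromkeys` de-duplication above works because `torch.Tensor` hashes by object identity, so a tied or shared parameter reachable from several modules gets a single key and is serialized only once. A minimal sketch of the idea, with made-up tensors:

```python
import torch

shared = torch.zeros(2)  # stands in for a parameter tied across two modules
tensors = [shared, torch.ones(2), shared]

# dict.fromkeys keeps the first occurrence of each object and preserves order
unique = list(dict.fromkeys(tensors))
assert len(unique) == 2 and unique[0] is shared
```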
@@ -124,6 +146,30 @@ def onload_(self):
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
+        if self.offload_to_disk_path:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
+                else:
+                    # Load directly to the target device (synchronous)
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                    )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -169,6 +215,26 @@ def onload_(self):
     @torch.compiler.disable()
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensors files exist on disk and, if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists
+            # and, if not, we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
+
+            # We do this to free up the RAM that is still holding the tensor data.
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+            return
 
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
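The two code paths above amount to a disk round-trip: each group is serialized once per session with safetensors, its RAM copy is replaced by a placeholder, and on the next onload the file is read back (optionally through pinned memory so the Host->Device copy can overlap compute). A minimal, self-contained sketch of that round-trip, using a single made-up parameter and file name:

```python
import os
import tempfile

import safetensors.torch
import torch

param = torch.nn.Parameter(torch.randn(4, 4))
path = os.path.join(tempfile.mkdtemp(), "group_example.safetensors")

# offload: persist the tensor data once, then point .data at a fresh
# placeholder so the original storage can be reclaimed (as the hook does)
safetensors.torch.save_file({"tensor_0": param.data}, path)
param.data = torch.empty_like(param.data, device="cpu")

# onload: read back to CPU, pin, and copy asynchronously if CUDA is available
loaded = safetensors.torch.load_file(path, device="cpu")
if torch.cuda.is_available():
    param.data = loaded["tensor_0"].pin_memory().to("cuda", non_blocking=True)
else:
    param.data = loaded["tensor_0"]
```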
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
 
@@ -363,6 +425,7 @@ def apply_group_offloading(
     use_stream: bool = False,
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
         offload_type (`str`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            RAM-constrained environments where a reasonable speed-memory trade-off is desired.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
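For orientation, here is a minimal sketch of how the new argument would be passed from user code. The `TinyModel` below is a hypothetical stand-in; `apply_group_offloading` and its parameters are the ones shown in this diff, and a CUDA device is assumed:

```python
import torch
from diffusers.hooks import apply_group_offloading

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level offloading groups torch.nn.ModuleList / torch.nn.Sequential blocks
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="./offload_dir",  # groups are parked here as safetensors files
)
```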
@@ -458,6 +524,7 @@ def apply_group_offloading(
             num_blocks_per_group=num_blocks_per_group,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
             module=module,
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             non_blocking=non_blocking,
             stream=stream,
             record_stream=record_stream,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
             The module to which group offloading is applied.
         offload_device (`torch.device`):
             The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            RAM-constrained environments where a reasonable speed-memory trade-off is desired.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
         non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
                 modules=current_modules,
                 offload_device=offload_device,
                 onload_device=onload_device,
+                offload_to_disk_path=offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
         modules=unmatched_modules,
         offload_device=offload_device,
         onload_device=onload_device,
+        offload_to_disk_path=offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            RAM-constrained environments where a reasonable speed-memory trade-off is desired.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
             modules=[submodule],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
             non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
             onload_device=onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk_path=offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
             non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
             modules=[],
             offload_device=offload_device,
             onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,