@@ -1334,6 +1334,143 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 offload_buffers = len(model._parameters) > 0
                 cpu_offload(model, device, offload_buffers=offload_buffers)

+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage: bool = False,
+        offload_to_disk_path: Optional[str] = None,
+        exclude_modules: Optional[Union[str, List[str]]] = None,
+    ) -> None:
+        r"""
+        Applies group offloading to the internal layers of the pipeline's `torch.nn.Module` components. To
+        understand what group offloading is, and where it is beneficial, it helps to first look at how the other
+        supported offloading methods work.
+
+        Typically, offloading is done at one of two levels:
+
+        - Module-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_model_cpu_offload()`
+          method. It works by offloading each component of a pipeline to the CPU for storage, and onloading it to
+          the accelerator device only when needed for computation. This method is more memory-efficient than
+          keeping all components on the accelerator, but the memory requirements are still quite high: completing
+          a forward pass requires memory equivalent to the size of the model in its runtime dtype plus the largest
+          intermediate activation tensors.
+        - Leaf-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_sequential_cpu_offload()`
+          method. It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for
+          storage, and onloading only the leaves to the accelerator device for computation. This uses the least
+          accelerator memory, but can be slower due to the excessive number of device synchronizations.
+
+        Group offloading is a middle ground between the two methods. It works by offloading groups of internal
+        layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than
+        module-level offloading. It is also faster than leaf-level/sequential offloading, as the number of device
+        synchronizations is reduced.
+
+        Another supported feature (on CUDA devices that support asynchronous data transfer streams) is the ability
+        to overlap data transfer and computation to reduce the overall execution time compared to sequential
+        offloading. This is enabled using layer prefetching with streams, i.e. the layer that is to be executed
+        next starts onloading to the accelerator device while the current layer is being executed, at the cost of
+        a slight increase in memory requirements. Note that leaf-level offloading is also supported by this
+        implementation and can be made much faster when streams are used.
+
+        Args:
+            onload_device (`torch.device`):
+                The device to which the group of modules is onloaded.
+            offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+                The device to which the group of modules is offloaded. This should typically be the CPU.
+            offload_type (`str` or `GroupOffloadingType`, defaults to `"block_level"`):
+                The type of offloading to be applied. Can be one of `"block_level"` or `"leaf_level"`.
+            offload_to_disk_path (`str`, *optional*, defaults to `None`):
+                The path to the directory to which parameters will be offloaded. Setting this option can be useful
+                in settings with limited RAM, where a reasonable speed-memory trade-off is desired.
+            num_blocks_per_group (`int`, *optional*):
+                The number of blocks per group. Required when `offload_type="block_level"`.
+            non_blocking (`bool`, defaults to `False`):
+                If True, offloading and onloading are done with non-blocking data transfers.
+            use_stream (`bool`, defaults to `False`):
+                If True, offloading and onloading are done asynchronously using a CUDA stream. This can be useful
+                for overlapping computation and data transfer.
+            record_stream (`bool`, defaults to `False`):
+                When enabled with `use_stream`, it marks the current tensor as having been used by this stream. It
+                is faster at the expense of slightly higher memory usage. Refer to the
+                [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+                for more details.
+            low_cpu_mem_usage (`bool`, defaults to `False`):
+                If True, CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+                This option only matters when streamed CPU offloading is used (i.e. `use_stream=True`). It can be
+                useful when CPU memory is a bottleneck, but may counteract the benefits of using streams.
+            exclude_modules (`Union[str, List[str]]`, *optional*, defaults to `None`):
+                A module name or list of module names to exclude from group offloading.
+
+        Example:
+        ```python
+        >>> from diffusers import DiffusionPipeline
+        >>> import torch
+
+        >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+        >>> pipe.enable_group_offload(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="leaf_level",
+        ...     use_stream=True,
+        ... )
+        >>> image = pipe("a beautiful sunset").images[0]
+        ```
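+
+        A block-level variant (illustrative; `num_blocks_per_group=2` is an arbitrary choice) onloads two blocks
+        of internal layers at a time, trading slightly more memory for fewer device synchronizations:
+
+        ```python
+        >>> pipe.enable_group_offload(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_type="block_level",
+        ...     num_blocks_per_group=2,
+        ... )
+        ```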
+        """
+        from ..hooks import apply_group_offloading
+
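+        # Accept either a single module name or a list of names for `exclude_modules`.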
+        if isinstance(exclude_modules, str):
+            exclude_modules = [exclude_modules]
+        elif exclude_modules is None:
+            exclude_modules = []
+
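+        # Warn about excluded names that do not match any pipeline component, without failing.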
+        unknown = set(exclude_modules) - self.components.keys()
+        if unknown:
+            logger.info(
+                f"The following modules are not present in the pipeline: {', '.join(unknown)}. Ignore if this is expected."
+            )
+
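+        # Apply group offloading to every `torch.nn.Module` component that is not excluded.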
+        for name, component in self.components.items():
+            if name not in exclude_modules and isinstance(component, torch.nn.Module):
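+                # Prefer the component's own `enable_group_offload` (e.g. `ModelMixin` subclasses);
+                # otherwise fall back to the lower-level `apply_group_offloading` helper from `..hooks`.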
+                if hasattr(component, "enable_group_offload"):
+                    component.enable_group_offload(
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+                else:
+                    apply_group_offloading(
+                        module=component,
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+
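+        # Modules excluded from offloading are moved to the onload device so they are ready for computation.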
+        if exclude_modules:
+            for module_name in exclude_modules:
+                module = getattr(self, module_name, None)
+                if module is not None and isinstance(module, torch.nn.Module):
+                    module.to(onload_device)
+                    logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
     def reset_device_map(self):
         r"""
         Resets the device maps (if any) to None.