Commit 25d9c70

feat: support group offloading at the pipeline level.
1 parent 764b624 commit 25d9c70

File tree

1 file changed: +135 -0 lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 135 additions & 0 deletions
```diff
@@ -1334,6 +1334,141 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
             offload_buffers = len(model._parameters) > 0
             cpu_offload(model, device, offload_buffers=offload_buffers)

+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
+        exclude_modules: Optional[Union[str, List[str]]] = None,
+    ) -> None:
+        r"""
+        Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
+        and where it is beneficial, we need to first provide some context on how other supported offloading methods
+        work.
+
+        Typically, offloading is done at two levels:
+        - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+          works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
+          device when needed for computation. This method is more memory-efficient than keeping all components on the
+          accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
+          equivalent to size of the model in runtime dtype + size of largest intermediate activation tensors to be able
+          to complete the forward pass.
+        - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
+          It
+          works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+          onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
+          memory, but can be slower due to the excessive number of device synchronizations.
+
+        Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
+        (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+        offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
+        is reduced.
+
+        Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
+        to overlap data transfer and computation to reduce the overall execution time compared to sequential
+        offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
+        starts onloading to the accelerator device while the current layer is being executed - this increases the
+        memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made
+        much faster when using streams.
+
+        Args:
+            onload_device (`torch.device`):
+                The device to which the group of modules are onloaded.
+            offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+                The device to which the group of modules are offloaded. This should typically be the CPU. Default is
+                CPU.
+            offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
+                The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
+                "block_level".
+            offload_to_disk_path (`str`, *optional*, defaults to `None`):
+                The path to the directory where parameters will be offloaded. Setting this option can be useful in
+                limited RAM environment settings where a reasonable speed-memory trade-off is desired.
+            num_blocks_per_group (`int`, *optional*):
+                The number of blocks per group when using offload_type="block_level". This is required when using
+                offload_type="block_level".
+            non_blocking (`bool`, defaults to `False`):
+                If True, offloading and onloading is done with non-blocking data transfer.
+            use_stream (`bool`, defaults to `False`):
+                If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+                overlapping computation and data transfer.
+            record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+                as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+                the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+                more details.
+            low_cpu_mem_usage (`bool`, defaults to `False`):
+                If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+                This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
+                useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
+            exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
+
+        Example:
+            ```python
+            >>> from diffusers import DiffusionPipeline
+            >>> import torch
+
+            >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+            >>> pipe.enable_group_offload(
+            ...     onload_device=torch.device("cuda"),
+            ...     offload_device=torch.device("cpu"),
+            ...     offload_type="leaf_level",
+            ...     use_stream=True,
+            ... )
+            >>> image = pipe("a beautiful sunset").images[0]
+            ```
+        """
+        from ..hooks import apply_group_offloading
+
+        if exclude_modules is not None and isinstance(exclude_modules, str):
+            exclude_modules = [exclude_modules]
+
+        unknown = set(exclude_modules) - set(self.components.keys())
+        if unknown:
+            logger.info(
+                f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
+            )
+
+        for name, component in self.components.items():
+            if name not in exclude_modules and isinstance(component, torch.nn.Module):
+                if hasattr(component, "enable_group_offload"):
+                    component.enable_group_offload(
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+                else:
+                    apply_group_offloading(
+                        module=component,
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+
+        if exclude_modules:
+            for module_name in exclude_modules:
+                module = getattr(self, module_name, None)
+                if module is not None and isinstance(module, torch.nn.Module):
+                    module.to(onload_device)
+                    logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
     def reset_device_map(self):
         r"""
         Resets the device maps (if any) to None.
```
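
Note: the docstring example above exercises the `leaf_level` path. For `block_level` offloading, the Args section states that `num_blocks_per_group` is required. Below is a minimal usage sketch of the new pipeline-level method, assuming a CUDA device, the same `Qwen/Qwen-Image` checkpoint used in the docstring example, and a `text_encoder` component to illustrate `exclude_modules`; the specific values are illustrative and not part of the commit.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Pipeline-level group offloading added by this commit: every torch.nn.Module
# component is group-offloaded unless it is listed in `exclude_modules`.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,  # required for "block_level"; 2 is an illustrative choice
    use_stream=True,  # overlap transfers with computation on CUDA
    exclude_modules="text_encoder",  # assumed component name; a plain string is normalized to a list
)

image = pipe("a beautiful sunset").images[0]
```

Per the implementation, excluded components that are `torch.nn.Module`s are not offloaded; they are placed on `onload_device` at the end of the call.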

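As the loop in the diff shows, the pipeline-level method mostly delegates per component: components that expose `enable_group_offload` (Diffusers models) are handled by their own method, while any other `torch.nn.Module` falls back to `apply_group_offloading` from `diffusers.hooks`. A hedged sketch of the manual per-component equivalent that this commit automates, restricted to keyword arguments the diff itself forwards and assuming `pipe` is loaded as above:

```python
import torch
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")

for component in pipe.components.values():
    if not isinstance(component, torch.nn.Module):
        continue  # non-module components (tokenizers, schedulers, ...) have nothing to offload
    if hasattr(component, "enable_group_offload"):
        # Diffusers models expose their own group-offload entry point.
        component.enable_group_offload(onload_device=onload_device, offload_type="leaf_level", use_stream=True)
    else:
        # Plain torch.nn.Module components go through the hook-level helper.
        apply_group_offloading(component, onload_device=onload_device, offload_type="leaf_level", use_stream=True)
```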