
Commit 8c7d612

Merge branch 'support-group-offloading-pipeline-level' into group-offload
2 parents: 224e7d2 + 506424c

File tree

src/diffusers/pipelines/pipeline_utils.py
tests/pipelines/test_pipelines_common.py

2 files changed: +205 -0 lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 137 additions & 0 deletions
@@ -1334,6 +1334,143 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
             offload_buffers = len(model._parameters) > 0
             cpu_offload(model, device, offload_buffers=offload_buffers)
 
+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
+        exclude_modules: Optional[Union[str, List[str]]] = None,
+    ) -> None:
+        r"""
+        Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
+        and where it is beneficial, we need to first provide some context on how other supported offloading methods
+        work.
+
+        Typically, offloading is done at two levels:
+        - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+          works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
+          device when needed for computation. This method is more memory-efficient than keeping all components on the
+          accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
+          equivalent to the size of the model in runtime dtype + the size of the largest intermediate activation
+          tensors to be able to complete the forward pass.
+        - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
+          It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+          onloading only the leaves to the accelerator device for computation. This uses the lowest amount of
+          accelerator memory, but can be slower due to the excessive number of device synchronizations.
+
+        Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
+        (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+        offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
+        is reduced.
+
+        Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
+        to overlap data transfer and computation to reduce the overall execution time compared to sequential
+        offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
+        starts onloading to the accelerator device while the current layer is being executed - this increases the
+        memory requirements slightly. Note that this implementation also supports leaf-level offloading, which can be
+        made much faster when using streams.
+
+        Args:
+            onload_device (`torch.device`):
+                The device to which the group of modules are onloaded.
+            offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+                The device to which the group of modules are offloaded. This should typically be the CPU.
+            offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
+                The type of offloading to be applied. Can be one of "block_level" or "leaf_level".
+            offload_to_disk_path (`str`, *optional*, defaults to `None`):
+                The path to the directory where parameters will be offloaded. Setting this option can be useful in
+                limited RAM environment settings where a reasonable speed-memory trade-off is desired.
+            num_blocks_per_group (`int`, *optional*):
+                The number of blocks per group. This is required when using `offload_type="block_level"`.
+            non_blocking (`bool`, defaults to `False`):
+                If True, offloading and onloading is done with non-blocking data transfer.
+            use_stream (`bool`, defaults to `False`):
+                If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+                overlapping computation and data transfer.
+            record_stream (`bool`, defaults to `False`):
+                When enabled with `use_stream`, it marks the current tensor as having been used by this stream. It is
+                faster at the expense of slightly more memory usage. Refer to the [PyTorch official
+                docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
+            low_cpu_mem_usage (`bool`, defaults to `False`):
+                If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+                This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
+                useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
+            exclude_modules (`Union[str, List[str]]`, defaults to `None`):
+                List of modules to exclude from offloading.
+
+        Example:
+            ```python
+            >>> from diffusers import DiffusionPipeline
+            >>> import torch
+
+            >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+            >>> pipe.enable_group_offload(
+            ...     onload_device=torch.device("cuda"),
+            ...     offload_device=torch.device("cpu"),
+            ...     offload_type="leaf_level",
+            ...     use_stream=True,
+            ... )
+            >>> image = pipe("a beautiful sunset").images[0]
+            ```
+        """
+        from ..hooks import apply_group_offloading
+
+        if isinstance(exclude_modules, str):
+            exclude_modules = [exclude_modules]
+        elif exclude_modules is None:
+            exclude_modules = []
+
+        unknown = set(exclude_modules) - self.components.keys()
+        if unknown:
+            logger.info(
+                f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
+            )
+
+        for name, component in self.components.items():
+            if name not in exclude_modules and isinstance(component, torch.nn.Module):
+                if hasattr(component, "enable_group_offload"):
+                    component.enable_group_offload(
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+                else:
+                    apply_group_offloading(
+                        module=component,
+                        onload_device=onload_device,
+                        offload_device=offload_device,
+                        offload_type=offload_type,
+                        num_blocks_per_group=num_blocks_per_group,
+                        non_blocking=non_blocking,
+                        use_stream=use_stream,
+                        record_stream=record_stream,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        offload_to_disk_path=offload_to_disk_path,
+                    )
+
+        if exclude_modules:
+            for module_name in exclude_modules:
+                module = getattr(self, module_name, None)
+                if module is not None and isinstance(module, torch.nn.Module):
+                    module.to(onload_device)
+                    logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
     def reset_device_map(self):
         r"""
         Resets the device maps (if any) to None.
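Taken together, the new method lets a single call configure offloading for every `torch.nn.Module` component of a pipeline. Below is a minimal usage sketch of this pipeline-level API covering the block-level path and `exclude_modules`, which the docstring example above does not exercise. The checkpoint is reused from that example; the `text_encoder` component name is an assumption and should be replaced with any key from `pipe.components`.

```python
import torch

from diffusers import DiffusionPipeline

# Checkpoint mirrors the docstring example; any DiffusionPipeline works the same way.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Block-level offloading moves groups of blocks between devices;
# `num_blocks_per_group` is required when offload_type="block_level".
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,  # overlap data transfer with computation on CUDA
    # Hypothetical component name -- replace with any key from `pipe.components`.
    # Excluded components are placed on the onload device instead of being offloaded.
    exclude_modules="text_encoder",
)

image = pipe("a beautiful sunset").images[0]
```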

tests/pipelines/test_pipelines_common.py

Lines changed: 68 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import numpy as np
 import PIL.Image
+import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
         max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
         self.assertLess(max_diff, expected_max_difference)
 
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_sanity_checks(self):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        module_names = sorted(
+            [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+        )
+        exclude_module_name = module_names[0]
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+            exclude_modules=exclude_module_name,
+        )
+        excluded_module = getattr(pipe, exclude_module_name)
+        self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
+
+        for name, component in pipe.components.items():
+            if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
+                # `component.device` prints the `onload_device` type. We should probably override the
+                # `device` property in `ModelMixin`.
+                component_device = next(component.parameters())[0].device
+                self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
+
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        # Regular inference.
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        pipe.to("cpu")
+        del pipe
+
+        # Inference with offloading
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+        )
+        pipe.set_progress_bar_config(disable=None)
+        inputs["generator"] = torch.manual_seed(0)
+        out_offload = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
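The sanity check above can also be reproduced outside the test harness. A rough sketch follows, assuming a CUDA machine and reusing the docstring's example checkpoint; the `vae` component name is an assumption, substitute any key from `pipe.components`.

```python
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    exclude_modules="vae",  # hypothetical component name; pick any key from `pipe.components`
)

for name, component in pipe.components.items():
    if not isinstance(component, torch.nn.Module):
        continue
    # As in the test, check a parameter's device directly; `component.device` reports
    # the onload device for group-offloaded modules.
    param_device = next(component.parameters()).device
    print(f"{name}: {param_device}")  # excluded components: cuda, offloaded ones: cpu
```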
