
Commit e141f5c

add tests

1 parent 25d9c70

File tree

2 files changed: 72 additions & 2 deletions


src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 2 deletions

@@ -1425,10 +1425,12 @@ def enable_group_offload(
         """
         from ..hooks import apply_group_offloading
 
-        if exclude_modules is not None and isinstance(exclude_modules, str):
+        if isinstance(exclude_modules, str):
             exclude_modules = [exclude_modules]
+        elif exclude_modules is None:
+            exclude_modules = []
 
-        unknown = set(exclude_modules) - set(self.components.keys())
+        unknown = set(exclude_modules) - self.components.keys()
         if unknown:
             logger.info(
                 f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."

tests/pipelines/test_pipelines_common.py

Lines changed: 68 additions & 0 deletions

@@ -9,6 +9,7 @@
 
 import numpy as np
 import PIL.Image
+import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
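
The new pytest import backs the pytest.skip(...) calls in the tests below. pytest.skip raises a special exception that the pytest runner understands, so it also works inside unittest.TestCase methods as long as the suite runs under pytest. A tiny hypothetical illustration:

# Hypothetical example: pytest.skip inside a unittest-style test.
import unittest
import pytest

class ExampleTests(unittest.TestCase):
    def test_skips_when_unsupported(self):
        supports_group_offloading = False  # stand-in for a real capability check
        if not supports_group_offloading:
            pytest.skip("component does not support group offloading")
        self.fail("not reached when the test is skipped")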
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
         max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
         self.assertLess(max_diff, expected_max_difference)
 
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_sanity_checks(self):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        module_names = sorted(
+            [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+        )
+        exclude_module_name = module_names[0]
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+            exclude_modules=exclude_module_name,
+        )
+        excluded_module = getattr(pipe, exclude_module_name)
+        self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
+
+        for name, component in pipe.components.items():
+            if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
+                # `component.device` prints the `onload_device` type. We should probably override the
+                # `device` property in `ModelMixin`.
+                component_device = next(component.parameters())[0].device
+                self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
+
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        # Regular inference.
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        pipe.to("cpu")
+        del pipe
+
+        # Inference with offloading
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+        )
+        pipe.set_progress_bar_config(disable=None)
+        inputs["generator"] = torch.manual_seed(0)
+        out_offload = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
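
Taken together, the new tests check two things: that exclude_modules keeps the named component on the onload device while every other module stays on the offload device, and that inference with group offloading is numerically equivalent to regular inference. At the user level the same API would be exercised roughly as follows; this is a hedged sketch with a placeholder checkpoint and prompt, assuming a CUDA accelerator:

# Hypothetical end-user usage of pipeline-level group offloading.
# The checkpoint name and prompt are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-checkpoint", torch_dtype=torch.bfloat16
)
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),  # weights are moved here just in time
    offload_device=torch.device("cpu"),  # and parked here when not in use
    offload_type="leaf_level",           # offload at leaf-module granularity
    exclude_modules="vae",               # keep this component on the accelerator
)
image = pipe("a placeholder prompt").images[0]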
