| 
9 | 9 | 
 
  | 
10 | 10 | import numpy as np  | 
11 | 11 | import PIL.Image  | 
 | 12 | +import pytest  | 
12 | 13 | import torch  | 
13 | 14 | import torch.nn as nn  | 
14 | 15 | from huggingface_hub import ModelCard, delete_repo  | 
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4  | 
2362 | 2363 |         max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()  | 
2363 | 2364 |         self.assertLess(max_diff, expected_max_difference)  | 
2364 | 2365 | 
 
  | 
 | 2366 | +    @require_torch_accelerator  | 
 | 2367 | +    def test_pipeline_level_group_offloading_sanity_checks(self):  | 
 | 2368 | +        components = self.get_dummy_components()  | 
 | 2369 | +        pipe: DiffusionPipeline = self.pipeline_class(**components)  | 
 | 2370 | + | 
 | 2371 | +        for name, component in pipe.components.items():  | 
 | 2372 | +            if hasattr(component, "_supports_group_offloading"):  | 
 | 2373 | +                if not component._supports_group_offloading:  | 
 | 2374 | +                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")  | 
 | 2375 | + | 
 | 2376 | +        module_names = sorted(  | 
 | 2377 | +            [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]  | 
 | 2378 | +        )  | 
 | 2379 | +        exclude_module_name = module_names[0]  | 
 | 2380 | +        offload_device = "cpu"  | 
 | 2381 | +        pipe.enable_group_offload(  | 
 | 2382 | +            onload_device=torch_device,  | 
 | 2383 | +            offload_device=offload_device,  | 
 | 2384 | +            offload_type="leaf_level",  | 
 | 2385 | +            exclude_modules=exclude_module_name,  | 
 | 2386 | +        )  | 
 | 2387 | +        excluded_module = getattr(pipe, exclude_module_name)  | 
 | 2388 | +        self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)  | 
 | 2389 | + | 
 | 2390 | +        for name, component in pipe.components.items():  | 
 | 2391 | +            if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):  | 
 | 2392 | +                # `component.device` prints the `onload_device` type. We should probably override the  | 
 | 2393 | +                # `device` property in `ModelMixin`.  | 
 | 2394 | +                component_device = next(component.parameters())[0].device  | 
 | 2395 | +                self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)  | 
 | 2396 | + | 
 | 2397 | +    @require_torch_accelerator  | 
 | 2398 | +    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):  | 
 | 2399 | +        components = self.get_dummy_components()  | 
 | 2400 | +        pipe: DiffusionPipeline = self.pipeline_class(**components)  | 
 | 2401 | + | 
 | 2402 | +        for name, component in pipe.components.items():  | 
 | 2403 | +            if hasattr(component, "_supports_group_offloading"):  | 
 | 2404 | +                if not component._supports_group_offloading:  | 
 | 2405 | +                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")  | 
 | 2406 | + | 
 | 2407 | +        # Regular inference.  | 
 | 2408 | +        pipe = pipe.to(torch_device)  | 
 | 2409 | +        pipe.set_progress_bar_config(disable=None)  | 
 | 2410 | +        torch.manual_seed(0)  | 
 | 2411 | +        inputs = self.get_dummy_inputs(torch_device)  | 
 | 2412 | +        inputs["generator"] = torch.manual_seed(0)  | 
 | 2413 | +        out = pipe(**inputs)[0]  | 
 | 2414 | + | 
 | 2415 | +        pipe.to("cpu")  | 
 | 2416 | +        del pipe  | 
 | 2417 | + | 
 | 2418 | +        # Inference with offloading  | 
 | 2419 | +        pipe: DiffusionPipeline = self.pipeline_class(**components)  | 
 | 2420 | +        offload_device = "cpu"  | 
 | 2421 | +        pipe.enable_group_offload(  | 
 | 2422 | +            onload_device=torch_device,  | 
 | 2423 | +            offload_device=offload_device,  | 
 | 2424 | +            offload_type="leaf_level",  | 
 | 2425 | +        )  | 
 | 2426 | +        pipe.set_progress_bar_config(disable=None)  | 
 | 2427 | +        inputs["generator"] = torch.manual_seed(0)  | 
 | 2428 | +        out_offload = pipe(**inputs)[0]  | 
 | 2429 | + | 
 | 2430 | +        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()  | 
 | 2431 | +        self.assertLess(max_diff, expected_max_difference)  | 
 | 2432 | + | 
2365 | 2433 | 
 
  | 
2366 | 2434 | @is_staging_test  | 
2367 | 2435 | class PipelinePushToHubTester(unittest.TestCase):  | 
 | 
0 commit comments