 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     floats_tensor,
     get_python_version,
     is_torch_compile,
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ def test_weight_overwrite(self):

         assert model.config.in_channels == 9

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0

         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)

         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1502,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2

             return model_memory_footprint, peak_inference_memory_allocated_mb

@@ -1512,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )

-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )

     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
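
For context (not part of the diff above): a minimal sketch of how the device-agnostic `backend_*` helpers introduced by this change are typically used outside the test class. It assumes the same `diffusers.utils.testing_utils` imports shown in the diff, that `torch_device` resolves to the active accelerator (e.g. "cuda" or "xpu"), and the helper name `peak_inference_memory_mb` is purely illustrative.

import gc

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    backend_synchronize,
    torch_device,
)


def reset_memory_stats():
    # Drop Python-level references, then flush and reset the accelerator's
    # allocator so the next peak-memory reading starts from a clean slate.
    gc.collect()
    backend_synchronize(torch_device)
    backend_empty_cache(torch_device)
    backend_reset_peak_memory_stats(torch_device)


def peak_inference_memory_mb():
    # Illustrative helper: peak bytes allocated on the active accelerator since
    # the last reset, converted to MiB as in the test's tolerance check.
    return backend_max_memory_allocated(torch_device) / 1024**2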