@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
15281528 test_fn (torch .float8_e5m2 , torch .float32 )
15291529 test_fn (torch .float8_e4m3fn , torch .bfloat16 )
15301530
1531+ @torch .no_grad ()
15311532 def test_layerwise_casting_inference (self ):
15321533 from diffusers .hooks .layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS
15331534
15341535 torch .manual_seed (0 )
15351536 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1536- model = self .model_class (** config ).eval ()
1537- model = model .to (torch_device )
1538- base_slice = model (** inputs_dict )[0 ].flatten ().detach ().cpu ().numpy ()
1537+ model = self .model_class (** config )
1538+ model .eval ()
1539+ model .to (torch_device )
1540+ base_slice = model (** inputs_dict )[0 ].detach ().flatten ().cpu ().numpy ()
15391541
15401542 def check_linear_dtype (module , storage_dtype , compute_dtype ):
15411543 patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
15731575 test_layerwise_casting (torch .float8_e4m3fn , torch .bfloat16 )
15741576
15751577 @require_torch_accelerator
1578+ @torch .no_grad ()
15761579 def test_layerwise_casting_memory (self ):
15771580 MB_TOLERANCE = 0.2
15781581 LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17061709 if not self .model_class ._supports_group_offloading :
17071710 pytest .skip ("Model does not support group offloading." )
17081711
1709- torch .manual_seed (0 )
1710- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1711- model = self .model_class (** init_dict )
1712-
17131712 torch .manual_seed (0 )
17141713 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
17151714 model = self .model_class (** init_dict )
@@ -1725,7 +1724,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17251724 ** additional_kwargs ,
17261725 )
17271726 has_safetensors = glob .glob (f"{ tmpdir } /*.safetensors" )
1728- assert has_safetensors , "No safetensors found in the directory."
1727+ self . assertTrue ( len ( has_safetensors ) > 0 , "No safetensors found in the offload directory." )
17291728 _ = model (** inputs_dict )[0 ]
17301729
17311730 def test_auto_model (self , expected_max_diff = 5e-5 ):
0 commit comments