@@ -651,22 +651,22 @@ def test_model_xattn_mask(self, mask_dtype):
 
         keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
         full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
-        assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
-            "a 'keep all' mask should give the same result as no mask"
-        )
+        assert full_cond_keepallmask_out.allclose(
+            full_cond_out, rtol=1e-05, atol=1e-05
+        ), "a 'keep all' mask should give the same result as no mask"
 
         trunc_cond = cond[:, :-1, :]
         trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
-        assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
-            "discarding the last token from our cond should change the result"
-        )
+        assert not trunc_cond_out.allclose(
+            full_cond_out, rtol=1e-05, atol=1e-05
+        ), "discarding the last token from our cond should change the result"
 
         batch, tokens, _ = cond.shape
         mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
         masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
-        assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
-            "masking the last token from our cond should be equivalent to truncating that token out of the condition"
-        )
+        assert masked_cond_out.allclose(
+            trunc_cond_out, rtol=1e-05, atol=1e-05
+        ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
 
     # see diffusers.models.attention_processor::Attention#prepare_attention_mask
     # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -694,9 +694,9 @@ def test_model_xattn_padding(self):
 
         trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
         trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
-        assert trunc_mask_out.allclose(keeplast_out), (
-            "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
-        )
+        assert trunc_mask_out.allclose(
+            keeplast_out
+        ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
 
     def test_custom_diffusion_processors(self):
         # enable deterministic behavior for gradient checkpointing
@@ -1111,12 +1111,12 @@ def test_load_attn_procs_raise_warning(self):
         with torch.no_grad():
             lora_sample_2 = model(**inputs_dict).sample
 
-        assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
-            "LoRA injected UNet should produce different results."
-        )
-        assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
-            "Loading from a saved checkpoint should produce identical results."
-        )
+        assert not torch.allclose(
+            non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+        assert torch.allclose(
+            lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
+        ), "Loading from a saved checkpoint should produce identical results."
 
     @require_peft_backend
     def test_save_attn_procs_raise_warning(self):
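The padding assertion in `test_model_xattn_padding` relies on the behavior referenced by the `Attention#prepare_attention_mask` comment above: a cross-attention mask shorter than the condition's token dimension is right-padded with "keep" entries, so a too-short 'discard-all' mask ends up keeping only the final token. Below is a minimal, self-contained sketch of that padding rule for boolean masks (True = keep). `pad_mask_to_length` is a hypothetical helper for illustration only, not the actual diffusers implementation.

```python
# Illustrative sketch only; NOT diffusers' prepare_attention_mask implementation.
import torch


def pad_mask_to_length(mask: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad a boolean (batch, tokens) mask with True ('keep') up to target_length."""
    current_length = mask.shape[-1]
    if current_length >= target_length:
        return mask
    pad = torch.ones(
        mask.shape[0], target_length - current_length, dtype=torch.bool, device=mask.device
    )
    return torch.cat([mask, pad], dim=-1)


batch, tokens = 2, 8
# A 'discard-all' mask that is one token shorter than the condition...
trunc_mask = torch.zeros(batch, tokens - 1, dtype=torch.bool)
# ...is padded with a trailing 'keep', i.e. it becomes a 'keep last' mask.
padded = pad_mask_to_length(trunc_mask, tokens)
keep_last = torch.arange(tokens).expand(batch, -1) == tokens - 1
assert torch.equal(padded, keep_last)
```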