Commit 84115dc

reset

1 parent: 601696d

77 files changed: +591 additions, -586 deletions

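Most hunks in this commit apply the same mechanical reformatting to multi-line assert statements: the layout that wraps the failure message in parentheses is swapped for the layout that wraps the condition (or the call's arguments) and puts the message after the closing parenthesis. The aMUSEd test files are the exception; they only move a bare `...` method body onto its own line. A minimal, self-contained sketch of the two equivalent assert layouts (the stand-in list below is illustrative, not taken from the test suite; the real tests inspect requests_mock request history):

# Stand-in data so the sketch runs on its own.
download_requests = ["HEAD", "HEAD", "HEAD", "GET", "GET"]

# Layout removed by this commit: the failure message is wrapped in parentheses.
assert download_requests.count("HEAD") == 3, (
    "3 HEAD requests one for config, one for model, and one for shard index file."
)

# Layout introduced by this commit: the condition is wrapped instead, and the
# message follows the closing parenthesis on the same line.
assert (
    download_requests.count("HEAD") == 3
), "3 HEAD requests one for config, one for model, and one for shard index file."

Both layouts are semantically identical; only the line wrapping of the assert changes.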

tests/models/test_modeling_common.py

Lines changed: 6 additions & 6 deletions
@@ -292,9 +292,9 @@ def test_one_request_upon_cached(self):
                )

            download_requests = [r.method for r in m.request_history]
-            assert download_requests.count("HEAD") == 3, (
-                "3 HEAD requests one for config, one for model, and one for shard index file."
-            )
+            assert (
+                download_requests.count("HEAD") == 3
+            ), "3 HEAD requests one for config, one for model, and one for shard index file."
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            with requests_mock.mock(real_http=True) as m:

@@ -306,9 +306,9 @@ def test_one_request_upon_cached(self):
                )

            cache_requests = [r.method for r in m.request_history]
-            assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
-                "We should call only `model_info` to check for commit hash and knowing if shard index is present."
-            )
+            assert (
+                "HEAD" == cache_requests[0] and len(cache_requests) == 2
+            ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."

    def test_weight_overwrite(self):
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 6 additions & 6 deletions
@@ -91,9 +91,9 @@ def test_xformers_enable_works(self):

        model.enable_xformers_memory_efficient_attention()

-        assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
-            "xformers is not enabled"
-        )
+        assert (
+            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):

@@ -165,9 +165,9 @@ def test_xformers_enable_works(self):

        model.enable_xformers_memory_efficient_attention()

-        assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
-            "xformers is not enabled"
-        )
+        assert (
+            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+        ), "xformers is not enabled"

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 18 additions & 18 deletions
@@ -651,22 +651,22 @@ def test_model_xattn_mask(self, mask_dtype):

        keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
        full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
-        assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
-            "a 'keep all' mask should give the same result as no mask"
-        )
+        assert full_cond_keepallmask_out.allclose(
+            full_cond_out, rtol=1e-05, atol=1e-05
+        ), "a 'keep all' mask should give the same result as no mask"

        trunc_cond = cond[:, :-1, :]
        trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
-        assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
-            "discarding the last token from our cond should change the result"
-        )
+        assert not trunc_cond_out.allclose(
+            full_cond_out, rtol=1e-05, atol=1e-05
+        ), "discarding the last token from our cond should change the result"

        batch, tokens, _ = cond.shape
        mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
        masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
-        assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
-            "masking the last token from our cond should be equivalent to truncating that token out of the condition"
-        )
+        assert masked_cond_out.allclose(
+            trunc_cond_out, rtol=1e-05, atol=1e-05
+        ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"

        # see diffusers.models.attention_processor::Attention#prepare_attention_mask
        # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.

@@ -694,9 +694,9 @@ def test_model_xattn_padding(self):

        trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
        trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
-        assert trunc_mask_out.allclose(keeplast_out), (
-            "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
-        )
+        assert trunc_mask_out.allclose(
+            keeplast_out
+        ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."

    def test_custom_diffusion_processors(self):
        # enable deterministic behavior for gradient checkpointing

@@ -1111,12 +1111,12 @@ def test_load_attn_procs_raise_warning(self):
        with torch.no_grad():
            lora_sample_2 = model(**inputs_dict).sample

-        assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
-            "LoRA injected UNet should produce different results."
-        )
-        assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
-            "Loading from a saved checkpoint should produce identical results."
-        )
+        assert not torch.allclose(
+            non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+        assert torch.allclose(
+            lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
+        ), "Loading from a saved checkpoint should produce identical results."

    @require_peft_backend
    def test_save_attn_procs_raise_warning(self):

tests/others/test_image_processor.py

Lines changed: 15 additions & 15 deletions
@@ -65,9 +65,9 @@ def test_vae_image_processor_pt(self):
            )
            out_np = self.to_np(out)
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
-            assert np.abs(in_np - out_np).max() < 1e-6, (
-                f"decoded output does not match input for output_type {output_type}"
-            )
+            assert (
+                np.abs(in_np - out_np).max() < 1e-6
+            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_np(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

@@ -78,9 +78,9 @@ def test_vae_image_processor_np(self):

            out_np = self.to_np(out)
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
-            assert np.abs(in_np - out_np).max() < 1e-6, (
-                f"decoded output does not match input for output_type {output_type}"
-            )
+            assert (
+                np.abs(in_np - out_np).max() < 1e-6
+            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_pil(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

@@ -93,9 +93,9 @@ def test_vae_image_processor_pil(self):
            for i, o in zip(input_pil, out):
                in_np = np.array(i)
                out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
-                assert np.abs(in_np - out_np).max() < 1e-6, (
-                    f"decoded output does not match input for output_type {output_type}"
-                )
+                assert (
+                    np.abs(in_np - out_np).max() < 1e-6
+                ), f"decoded output does not match input for output_type {output_type}"

    def test_preprocess_input_3d(self):
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

@@ -293,9 +293,9 @@ def test_vae_image_processor_resize_pt(self):
        scale = 2
        out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
        exp_pt_shape = (b, c, h // scale, w // scale)
-        assert out_pt.shape == exp_pt_shape, (
-            f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
-        )
+        assert (
+            out_pt.shape == exp_pt_shape
+        ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."

    def test_vae_image_processor_resize_np(self):
        image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)

@@ -305,6 +305,6 @@ def test_vae_image_processor_resize_np(self):
        input_np = self.to_np(input_pt)
        out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
        exp_np_shape = (b, h // scale, w // scale, c)
-        assert out_np.shape == exp_np_shape, (
-            f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
-        )
+        assert (
+            out_np.shape == exp_np_shape
+        ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."

tests/pipelines/amused/test_amused.py

Lines changed: 2 additions & 1 deletion
@@ -125,7 +125,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
        self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)

    @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...


@slow

tests/pipelines/amused/test_amused_img2img.py

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
        self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)

    @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...


@slow

tests/pipelines/amused/test_amused_inpaint.py

Lines changed: 2 additions & 1 deletion
@@ -130,7 +130,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
        self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)

    @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...


@slow

tests/pipelines/aura_flow/test_pipeline_aura_flow.py

Lines changed: 12 additions & 12 deletions
@@ -139,9 +139,9 @@ def test_fused_qkv_projections(self):
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

@@ -155,15 +155,15 @@ def test_fused_qkv_projections(self):
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

-        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
-            "Fusion of QKV projections shouldn't affect the outputs."
-        )
-        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
-            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
-        )
-        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
-            "Original outputs should match when fused QKV projections are disabled."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."

    @unittest.skip("xformers attention processor does not exist for AuraFlow")
    def test_xformers_attention_forwardGenerator_pass(self):

tests/pipelines/blipdiffusion/test_blipdiffusion.py

Lines changed: 3 additions & 3 deletions
@@ -195,6 +195,6 @@ def test_blipdiffusion(self):
            [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
        )

-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
-            f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
-        )
+        assert (
+            np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 12 additions & 12 deletions
@@ -295,9 +295,9 @@ def test_fused_qkv_projections(self):
        original_image_slice = frames[0, -2:, -1, -3:, -3:]

        pipe.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-        )
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

@@ -311,15 +311,15 @@ def test_fused_qkv_projections(self):
        frames = pipe(**inputs).frames
        image_slice_disabled = frames[0, -2:, -1, -3:, -3:]

-        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
-            "Fusion of QKV projections shouldn't affect the outputs."
-        )
-        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
-            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
-        )
-        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
-            "Original outputs should match when fused QKV projections are disabled."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."


@slow
