Commit b2ca39c

[tests] test encode_prompt() in isolation (huggingface#10438)
* poc encode_prompt() tests
* fix
* updates.
* fixes
* fixes
* updates
* updates
* updates
* revert
* updates
* updates
* updates
* updates
* remove SDXLOptionalComponentsTesterMixin.
* remove tests that directly leveraged encode_prompt() in some way or the other.
* fix imports.
* remove _save_load
* fixes
* fixes
* fixes
* fixes
1 parent 5321712 commit b2ca39c

82 files changed: +609, -893 lines changed

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 2 additions & 1 deletion
@@ -268,7 +268,8 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]

-        self.tokenizer.padding_side = "right"
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"

         # See Section 3.1. of the paper.
         max_length = max_sequence_length

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 2 additions & 1 deletion
@@ -312,7 +312,8 @@ def encode_prompt(
         else:
             batch_size = prompt_embeds.shape[0]

-        self.tokenizer.padding_side = "right"
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"

         # See Section 3.1. of the paper.
         max_length = max_sequence_length
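
Both Sana pipelines above receive the same guard, presumably so that encode_prompt() can still run when the tokenizer is not set. A compressed, illustrative sketch of the pattern (my own example, not code from the commit; MinimalPipeline is a made-up class, not the actual SanaPipeline):

# Illustrative only: an encode_prompt() that tolerates a missing tokenizer,
# mirroring the guard added in the two Sana pipelines above.
class MinimalPipeline:
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer  # optional component; may be None

    def encode_prompt(self, prompt):
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"
        # ...tokenization and the text-encoder forward pass would follow here...
        return prompt


MinimalPipeline().encode_prompt("a photo of an astronaut")  # no AttributeError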
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import ast
+import importlib
+import inspect
+import textwrap
+
+
+class ReturnNameVisitor(ast.NodeVisitor):
+    """Thanks to ChatGPT for pairing."""
+
+    def __init__(self):
+        self.return_names = []
+
+    def visit_Return(self, node):
+        # Check if the return value is a tuple.
+        if isinstance(node.value, ast.Tuple):
+            for elt in node.value.elts:
+                if isinstance(elt, ast.Name):
+                    self.return_names.append(elt.id)
+                else:
+                    try:
+                        self.return_names.append(ast.unparse(elt))
+                    except Exception:
+                        self.return_names.append(str(elt))
+        else:
+            if isinstance(node.value, ast.Name):
+                self.return_names.append(node.value.id)
+            else:
+                try:
+                    self.return_names.append(ast.unparse(node.value))
+                except Exception:
+                    self.return_names.append(str(node.value))
+        self.generic_visit(node)
+
+    def _determine_parent_module(self, cls):
+        from diffusers import DiffusionPipeline
+        from diffusers.models.modeling_utils import ModelMixin
+
+        if issubclass(cls, DiffusionPipeline):
+            return "pipelines"
+        elif issubclass(cls, ModelMixin):
+            return "models"
+        else:
+            raise NotImplementedError
+
+    def get_ast_tree(self, cls, attribute_name="encode_prompt"):
+        parent_module_name = self._determine_parent_module(cls)
+        main_module = importlib.import_module(f"diffusers.{parent_module_name}")
+        current_cls_module = getattr(main_module, cls.__name__)
+        source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
+        source_code = textwrap.dedent(source_code)
+        tree = ast.parse(source_code)
+        return tree
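
For context, a short standalone sketch of how a visitor like this is exercised (my own illustration, not part of the commit; it assumes the ReturnNameVisitor class defined above is in scope, and sample_encode_prompt is a made-up toy function):

# Run the visitor over a toy function's source and read back the return names.
import ast
import inspect
import textwrap


def sample_encode_prompt(prompt):
    # Hypothetical stand-in for a pipeline's encode_prompt(); returns a tuple.
    prompt_embeds = prompt.upper()
    negative_prompt_embeds = prompt.lower()
    return prompt_embeds, negative_prompt_embeds


tree = ast.parse(textwrap.dedent(inspect.getsource(sample_encode_prompt)))
visitor = ReturnNameVisitor()  # assumes the class above has been imported
visitor.visit(tree)
print(visitor.return_names)  # ['prompt_embeds', 'negative_prompt_embeds']

This is presumably how the new mixin test discovers which names encode_prompt() returns without running the full pipeline.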

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 8 additions & 0 deletions
@@ -548,6 +548,14 @@ def test_xformers_attention_forwardGenerator_pass(self):
     def test_vae_slicing(self):
         return super().test_vae_slicing(image_count=2)

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 @slow
 @require_torch_accelerator
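
The same override recurs in the AnimateDiff variants below. As an aside, a standalone distillation of the shared pattern (my own sketch; build_extra_encode_prompt_kwargs is a made-up helper name, not part of diffusers):

import torch


def build_extra_encode_prompt_kwargs(dummy_inputs, device="cpu"):
    # Keyword arguments that encode_prompt() requires but the shared test
    # mixin cannot infer from the pipeline alone.
    return {
        "device": torch.device(device).type,
        "num_images_per_prompt": 1,
        # Classifier-free guidance is active only when guidance_scale > 1.0.
        "do_classifier_free_guidance": dummy_inputs.get("guidance_scale", 1.0) > 1.0,
    }


print(build_extra_encode_prompt_kwargs({"guidance_scale": 7.5}))
# {'device': 'cpu', 'num_images_per_prompt': 1, 'do_classifier_free_guidance': True}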

tests/pipelines/animatediff/test_animatediff_controlnet.py

Lines changed: 8 additions & 0 deletions
@@ -517,3 +517,11 @@ def test_vae_slicing(self, video_count=2):
         output_2 = pipe(**inputs)

         assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)

tests/pipelines/animatediff/test_animatediff_sdxl.py

Lines changed: 8 additions & 29 deletions
@@ -21,7 +21,6 @@
     IPAdapterTesterMixin,
     PipelineTesterMixin,
     SDFunctionTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
 )


@@ -36,7 +35,6 @@ class AnimateDiffPipelineSDXLFastTests(
     IPAdapterTesterMixin,
     SDFunctionTesterMixin,
     PipelineTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
     unittest.TestCase,
 ):
     pipeline_class = AnimateDiffSDXLPipeline
@@ -250,33 +248,6 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

-    def test_prompt_embeds(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        prompt = inputs.pop("prompt")
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = pipe.encode_prompt(prompt)
-
-        pipe(
-            **inputs,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-        )
-
-    def test_save_load_optional_components(self):
-        self._test_save_load_optional_components()
-
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -305,3 +276,11 @@ def test_xformers_attention_forwardGenerator_pass(self):

         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+    @unittest.skip("Test currently not supported.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+
+    @unittest.skip("Functionality is tested elsewhere.")
+    def test_save_load_optional_components(self):
+        pass

tests/pipelines/animatediff/test_animatediff_sparsectrl.py

Lines changed: 8 additions & 0 deletions
@@ -484,3 +484,11 @@ def test_free_init_with_schedulers(self):

     def test_vae_slicing(self):
         return super().test_vae_slicing(image_count=2)
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)

tests/pipelines/animatediff/test_animatediff_video2video.py

Lines changed: 8 additions & 0 deletions
@@ -544,3 +544,11 @@ def test_free_noise_multi_prompt(self):
         inputs["strength"] = 0.5
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
         pipe(**inputs).frames[0]
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)

tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py

Lines changed: 8 additions & 0 deletions
@@ -533,3 +533,11 @@ def test_free_noise_multi_prompt(self):
         inputs["strength"] = 0.5
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
         pipe(**inputs).frames[0]
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)

tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 5 additions & 0 deletions
@@ -508,9 +508,14 @@ def test_to_dtype(self):
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

+    @unittest.skip("Test not supported.")
     def test_sequential_cpu_offload_forward_pass(self):
         pass

+    @unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+

 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
