Skip to content

Commit 0c11c8c

Browse files
authored
[CI] Fix SANA tests (huggingface#11756)
update
1 parent fc51583 commit 0c11c8c

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/pipelines/sana/test_sana_controlnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
enable_full_determinism,
3131
torch_device,
3232
)
33+
from diffusers.utils.torch_utils import randn_tensor
3334

3435
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
3536
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -151,7 +152,7 @@ def get_dummy_inputs(self, device, seed=0):
151152
else:
152153
generator = torch.Generator(device=device).manual_seed(seed)
153154

154-
control_image = torch.randn(1, 3, 32, 32, generator=generator)
155+
control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
155156
inputs = {
156157
"prompt": "",
157158
"negative_prompt": "",

tests/pipelines/sana/test_sana_sprint_img2img.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
enable_full_determinism,
2525
torch_device,
2626
)
27+
from diffusers.utils.torch_utils import randn_tensor
2728

2829
from ..pipeline_params import (
2930
IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -137,7 +138,7 @@ def get_dummy_inputs(self, device, seed=0):
137138
generator = torch.manual_seed(seed)
138139
else:
139140
generator = torch.Generator(device=device).manual_seed(seed)
140-
image = torch.randn(1, 3, 32, 32, generator=generator)
141+
image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
141142
inputs = {
142143
"prompt": "",
143144
"image": image,

0 commit comments

Comments
 (0)