Skip to content

Commit b756ec6

Browse files
authored
Allow the UNet's `sample_size` attribute to be a tuple (h, w) in StableDiffusionPipeline (huggingface#10181)
1 parent d8825e7 commit b756ec6

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class conditioning with `class_embed_type` equal to `None`.
170170
@register_to_config
171171
def __init__(
172172
self,
173-
sample_size: Optional[int] = None,
173+
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
174174
in_channels: int = 4,
175175
out_channels: int = 4,
176176
center_input_sample: bool = False,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,12 @@ def __init__(
255255
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
256256
version.parse(unet.config._diffusers_version).base_version
257257
) < version.parse("0.9.0.dev0")
258-
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
258+
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
259+
is_unet_sample_size_less_64 = (
260+
hasattr(unet.config, "sample_size")
261+
and self._is_unet_config_sample_size_int
262+
and unet.config.sample_size < 64
263+
)
259264
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
260265
deprecation_message = (
261266
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -902,8 +907,18 @@ def __call__(
902907
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
903908

904909
# 0. Default height and width to unet
905-
height = height or self.unet.config.sample_size * self.vae_scale_factor
906-
width = width or self.unet.config.sample_size * self.vae_scale_factor
910+
if not height or not width:
911+
height = (
912+
self.unet.config.sample_size
913+
if self._is_unet_config_sample_size_int
914+
else self.unet.config.sample_size[0]
915+
)
916+
width = (
917+
self.unet.config.sample_size
918+
if self._is_unet_config_sample_size_int
919+
else self.unet.config.sample_size[1]
920+
)
921+
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
907922
# to deal with lora scaling and other possible forward hooks
908923

909924
# 1. Check inputs. Raise error if not correct

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,14 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
840840
# they should be the same
841841
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
842842

843+
def test_pipeline_accept_tuple_type_unet_sample_size(self):
844+
# Check that the pipeline accepts a UNet whose `sample_size` is a (height, width) pair rather than a single int.
845+
sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
846+
sample_size = [60, 80]
847+
customised_unet = UNet2DConditionModel(sample_size=sample_size)
848+
pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
849+
assert pipe.unet.config.sample_size == sample_size
850+
843851

844852
@slow
845853
@require_torch_gpu

0 commit comments

Comments
 (0)