@@ -255,7 +255,12 @@ def __init__(
255255 is_unet_version_less_0_9_0 = hasattr (unet .config , "_diffusers_version" ) and version .parse (
256256 version .parse (unet .config ._diffusers_version ).base_version
257257 ) < version .parse ("0.9.0.dev0" )
258- is_unet_sample_size_less_64 = hasattr (unet .config , "sample_size" ) and unet .config .sample_size < 64
258+ self ._is_unet_config_sample_size_int = isinstance (unet .config .sample_size , int )
259+ is_unet_sample_size_less_64 = (
260+ hasattr (unet .config , "sample_size" )
261+ and self ._is_unet_config_sample_size_int
262+ and unet .config .sample_size < 64
263+ )
259264 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64 :
260265 deprecation_message = (
261266 "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -902,8 +907,18 @@ def __call__(
902907 callback_on_step_end_tensor_inputs = callback_on_step_end .tensor_inputs
903908
904909 # 0. Default height and width to unet
905- height = height or self .unet .config .sample_size * self .vae_scale_factor
906- width = width or self .unet .config .sample_size * self .vae_scale_factor
910+ if not height or not width :
911+ height = (
912+ self .unet .config .sample_size
913+ if self ._is_unet_config_sample_size_int
914+ else self .unet .config .sample_size [0 ]
915+ )
916+ width = (
917+ self .unet .config .sample_size
918+ if self ._is_unet_config_sample_size_int
919+ else self .unet .config .sample_size [1 ]
920+ )
921+ height , width = height * self .vae_scale_factor , width * self .vae_scale_factor
907922 # to deal with lora scaling and other possible forward hooks
908923
909924 # 1. Check inputs. Raise error if not correct
0 commit comments