Commit e238284

Merge branch 'huggingface:main' into cogview4
2 parents: e94999e + 5d2d239

7 files changed, +19 -18 lines

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 2 additions & 2 deletions
@@ -695,7 +695,7 @@ def preprocess_images(examples):
         )
         # We need to ensure that the original and the edited images undergo the same
         # augmentation transforms.
-        images = np.concatenate([original_images, edited_images])
+        images = np.stack([original_images, edited_images])
         images = torch.tensor(images)
         images = 2 * (images / 255) - 1
         return train_transforms(images)
@@ -706,7 +706,7 @@ def preprocess_train(examples):
         # Since the original and edited images were concatenated before
         # applying the transformations, we need to separate them and reshape
         # them accordingly.
-        original_images, edited_images = preprocessed_images.chunk(2)
+        original_images, edited_images = preprocessed_images
         original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
         edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
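For context, a minimal standalone sketch (toy arrays and shapes are assumptions, and the train_transforms step is skipped) of why switching from np.concatenate to np.stack makes the later split simpler: stack adds a new leading axis, so the transformed tensor can be unpacked into the two image groups instead of chunked.

# Minimal sketch with toy data; shapes are assumptions, train_transforms is omitted.
import numpy as np
import torch

original_images = np.zeros((4, 3, 64, 64), dtype=np.float32)  # stand-in for the original batch
edited_images = np.ones((4, 3, 64, 64), dtype=np.float32)     # stand-in for the edited batch

# np.concatenate would give shape (8, 3, 64, 64) and require .chunk(2) later;
# np.stack adds a leading axis instead, giving shape (2, 4, 3, 64, 64).
images = np.stack([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1

# Iterating over dim 0 yields exactly the two groups, so plain unpacking works.
original, edited = images
assert original.shape == edited.shape == (4, 3, 64, 64)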

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 2 additions & 2 deletions
@@ -766,7 +766,7 @@ def preprocess_images(examples):
         )
         # We need to ensure that the original and the edited images undergo the same
         # augmentation transforms.
-        images = np.concatenate([original_images, edited_images])
+        images = np.stack([original_images, edited_images])
         images = torch.tensor(images)
         images = 2 * (images / 255) - 1
         return train_transforms(images)
@@ -906,7 +906,7 @@ def preprocess_train(examples):
         # Since the original and edited images were concatenated before
         # applying the transformations, we need to separate them and reshape
         # them accordingly.
-        original_images, edited_images = preprocessed_images.chunk(2)
+        original_images, edited_images = preprocessed_images
         original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
         edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 5 deletions
@@ -405,11 +405,12 @@ def set_use_memory_efficient_attention_xformers(
         else:
             try:
                 # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                )
+                dtype = None
+                if attention_op is not None:
+                    op_fw, op_bw = attention_op
+                    dtype, *_ = op_fw.SUPPORTED_DTYPES
+                q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                _ = xformers.ops.memory_efficient_attention(q, q, q)
             except Exception as e:
                 raise e
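A hedged standalone sketch of the new probe: the dtype is taken from whatever the chosen op advertises rather than defaulting to float32. The specific op below (flash attention) and the availability of a CUDA GPU that supports it are assumptions, not something this commit pins.

# Hedged sketch; assumes xformers is installed and a CUDA GPU supporting the flash op.
import torch
import xformers.ops
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp  # (forward op, backward op) pair

attention_op = MemoryEfficientAttentionFlashAttentionOp

dtype = None
if attention_op is not None:
    op_fw, op_bw = attention_op
    dtype, *_ = op_fw.SUPPORTED_DTYPES  # a dtype the forward op advertises, e.g. torch.float16

q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xformers.ops.memory_efficient_attention(q, q, q)  # raises if memory-efficient attention cannot run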

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 4 additions & 2 deletions
@@ -160,8 +160,10 @@ def check_inputs(
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+            )

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
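A rough illustration of the tightened check, assuming the common vae_scale_factor of 8 (so the requirement becomes divisibility by 16); this is not pipeline code, just the arithmetic in isolation.

# Rough illustration; vae_scale_factor = 8 is an assumed typical value.
vae_scale_factor = 8

def check_resolution(height: int, width: int) -> None:
    multiple = vae_scale_factor * 2  # 16 for an 8x VAE
    if height % multiple != 0 or width % multiple != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {multiple} but are {height} and {width}."
        )

check_resolution(512, 512)    # passes
# check_resolution(520, 512)  # would raise: 520 is not a multiple of 16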

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def check_inputs(
         prompt_template=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -188,7 +188,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
     https://arxiv.org/abs/2302.08453
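FromSingleFileMixin is what provides the from_single_file class method, so the adapter pipeline can now be loaded from an all-in-one Stable Diffusion checkpoint. A hedged usage sketch; the .safetensors path is a placeholder and the adapter repo is just an example, neither is referenced by this commit.

# Hedged usage sketch; checkpoint path is a placeholder, adapter repo is an example.
import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2", torch_dtype=torch.float16)

# from_single_file comes from FromSingleFileMixin and loads a single-file
# Stable Diffusion checkpoint instead of a diffusers-format repo.
pipe = StableDiffusionAdapterPipeline.from_single_file(
    "path/to/stable-diffusion-v1-5.safetensors",  # placeholder single-file checkpoint
    adapter=adapter,
    torch_dtype=torch.float16,
)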

tests/pipelines/hunyuan_video/test_hunyuan_video.py

Lines changed: 2 additions & 4 deletions
@@ -132,7 +132,7 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):

         torch.manual_seed(0)
         text_encoder = LlamaModel(llama_text_encoder_config)
-        tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")

         torch.manual_seed(0)
         text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
@@ -155,10 +155,8 @@ def get_dummy_inputs(self, device, seed=0):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)

-        # Cannot test with dummy prompt because tokenizers are not configured correctly.
-        # TODO(aryan): create dummy tokenizers and using from hub
         inputs = {
-            "prompt": "",
+            "prompt": "dance monkey",
             "prompt_template": {
                 "template": "{}",
                 "crop_start": 0,

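The replacement tokenizer lives in the "tokenizer" subfolder of the finetrainers/dummy-hunyaunvideo repo named in the diff, which is why the test now passes subfolder= and can use a real prompt. A minimal sketch outside the test suite (assumes network access to the Hugging Face Hub):

# Minimal sketch (assumes Hub access): load the dummy tokenizer from its subfolder
# and tokenize the non-empty prompt the test now uses.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
encoded = tokenizer("dance monkey")
print(encoded["input_ids"])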