diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index fc3022cf7b35..752219b4abd1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -598,6 +598,8 @@ title: Attention Processor - local: api/activations title: Custom activation functions + - local: api/cache + title: Caching methods - local: api/normalization title: Custom normalization layers - local: api/utilities diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md new file mode 100644 index 000000000000..403dbf88b431 --- /dev/null +++ b/docs/source/en/api/cache.md @@ -0,0 +1,49 @@ + + +# Caching methods + +## Pyramid Attention Broadcast + +[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. + +Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation. + +Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request. + +```python +import torch +from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of +# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention +# broadcast is active, leader to slower inference speeds. However, large intervals can lead to +# poorer quality of generated videos. +config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + current_timestep_callback=lambda: pipe.current_timestep, +) +pipe.transformer.enable_cache(config) +``` + +### CacheMixin + +[[autodoc]] CacheMixin + +### PyramidAttentionBroadcastConfig + +[[autodoc]] PyramidAttentionBroadcastConfig + +[[autodoc]] apply_pyramid_attention_broadcast diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 1d7a367ecc60..4895bd150114 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -80,7 +80,6 @@ USE_PEFT_BACKEND, BaseOutput, deprecate, - is_torch_version, is_torch_xla_available, logging, replace_example_docstring, @@ -869,23 +868,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1030,17 +1013,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1049,12 +1021,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, @@ -1192,23 +1159,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1282,10 +1233,6 @@ def __init__( ] ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -1365,19 +1312,8 @@ def forward( # Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -1385,7 +1321,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( @@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py new file mode 100644 index 000000000000..3bb080510902 --- /dev/null +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -0,0 +1,1351 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> pipe = DiffusionPipeline.from_pretrained( + >>> "black-forest-labs/FLUX.1-dev", + >>> custom_pipeline="pipeline_flux_semantic_guidance", + >>> torch_dtype=torch.bfloat16 + >>> ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe( + >>> prompt=prompt, + >>> num_inference_steps=28, + >>> guidance_scale=3.5, + >>> editing_prompt=["cat", "dog"], # changes from cat to dog. + >>> reverse_editing_direction=[True, False], + >>> edit_warmup_steps=[6, 8], + >>> edit_guidance_scale=[6, 6.5], + >>> edit_threshold=[0.89, 0.89], + >>> edit_cooldown_steps = [25, 27], + >>> edit_momentum_scale=0.3, + >>> edit_mom_beta=0.6, + >>> generator=torch.Generator(device="cuda").manual_seed(6543), + >>> ).images[0] + >>> image.save("semantic_flux.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxSemanticGuidancePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux pipeline for text-to-image generation with semantic guidance. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_text_with_editing( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + editing_prompt: Optional[List[str]] = None, + editing_prompt_2: Optional[List[str]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + """ + Encode text prompts with editing prompts and negative prompts for semantic guidance. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + prompt_2 (`str` or `List[str]`): + The prompt or prompts to guide image generation for second tokenizer. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + editing_prompt (`str` or `List[str]`, *optional*): + The editing prompts for semantic guidance. + editing_prompt_2 (`str` or `List[str]`, *optional*): + The editing prompts for semantic guidance for second tokenizer. + editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-computed embeddings for editing prompts. + pooled_editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-computed pooled embeddings for editing prompts. + device (`torch.device`, *optional*): + The device to use for computation. + num_images_per_prompt (`int`, defaults to 1): + Number of images to generate per prompt. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for text encoding. + lora_scale (`float`, *optional*): + Scale factor for LoRA layers if used. + + Returns: + tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, int]: + A tuple containing the prompt embeddings, pooled prompt embeddings, + text IDs, and number of enabled editing prompts. + """ + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError("Prompt must be provided as string or list of strings") + + # Get base prompt embeddings + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # Handle editing prompts + if editing_prompt_embeds is not None: + enabled_editing_prompts = int(editing_prompt_embeds.shape[0]) + edit_text_ids = [] + elif editing_prompt is not None: + editing_prompt_embeds = [] + pooled_editing_prompt_embeds = [] + edit_text_ids = [] + + editing_prompt_2 = editing_prompt if editing_prompt_2 is None else editing_prompt_2 + for edit_1, edit_2 in zip(editing_prompt, editing_prompt_2): + e_prompt_embeds, pooled_embeds, e_ids = self.encode_prompt( + prompt=edit_1, + prompt_2=edit_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + editing_prompt_embeds.append(e_prompt_embeds) + pooled_editing_prompt_embeds.append(pooled_embeds) + edit_text_ids.append(e_ids) + + enabled_editing_prompts = len(editing_prompt) + + else: + edit_text_ids = [] + enabled_editing_prompts = 0 + + if enabled_editing_prompts: + for idx in range(enabled_editing_prompts): + editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0) + pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0) + + return ( + prompt_embeds, + pooled_prompt_embeds, + editing_prompt_embeds, + pooled_editing_prompt_embeds, + text_ids, + edit_text_ids, + enabled_editing_prompts, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + editing_prompt: Optional[Union[str, List[str]]] = None, + editing_prompt_2: Optional[Union[str, List[str]]] = None, + editing_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None, + reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, + edit_guidance_scale: Optional[Union[float, List[float]]] = 5, + edit_warmup_steps: Optional[Union[int, List[int]]] = 8, + edit_cooldown_steps: Optional[Union[int, List[int]]] = None, + edit_threshold: Optional[Union[float, List[float]]] = 0.9, + edit_momentum_scale: Optional[float] = 0.1, + edit_mom_beta: Optional[float] = 0.4, + edit_weights: Optional[List[float]] = None, + sem_guidance: Optional[List[torch.Tensor]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + editing_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image editing. If not defined, no editing will be performed. + editing_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image editing. If not defined, will use editing_prompt instead. + editing_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings for editing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, text embeddings will be generated from `editing_prompt` input argument. + reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): + Whether to reverse the editing direction for each editing prompt. + edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): + Guidance scale for the editing process. If provided as a list, each value corresponds to an editing prompt. + edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10): + Number of warmup steps for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to None): + Number of cooldown steps for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): + Threshold for editing guidance. If provided as a list, each value corresponds to an editing prompt. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of momentum to be added to the editing guidance at each diffusion step. + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Beta value for momentum calculation in editing guidance. + edit_weights (`List[float]`, *optional*): + Weights for each editing prompt. + sem_guidance (`List[torch.Tensor]`, *optional*): + Pre-generated semantic guidance. If provided, it will be used instead of calculating guidance from editing prompts. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeds is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_embeds.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + editing_prompts_embeds, + pooled_editing_prompt_embeds, + text_ids, + edit_text_ids, + enabled_editing_prompts, + ) = self.encode_text_with_editing( + prompt=prompt, + prompt_2=prompt_2, + pooled_prompt_embeds=pooled_prompt_embeds, + editing_prompt=editing_prompt, + editing_prompt_2=editing_prompt_2, + pooled_editing_prompt_embeds=pooled_editing_prompt_embeds, + lora_scale=lora_scale, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + negative_prompt_embeds = torch.cat([negative_prompt_embeds] * batch_size, dim=0) + negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds] * batch_size, dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + edit_momentum = None + if edit_warmup_steps: + tmp_e_warmup_steps = edit_warmup_steps if isinstance(edit_warmup_steps, list) else [edit_warmup_steps] + min_edit_warmup_steps = min(tmp_e_warmup_steps) + else: + min_edit_warmup_steps = 0 + + if edit_cooldown_steps: + tmp_e_cooldown_steps = ( + edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps] + ) + max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps) + else: + max_edit_cooldown_steps = num_inference_steps + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.tensor([guidance_scale], device=device) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps: + noise_pred_edit_concepts = [] + for e_embed, pooled_e_embed, e_text_id in zip( + editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids + ): + noise_pred_edit = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_e_embed, + encoder_hidden_states=e_embed, + txt_ids=e_text_id, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred_edit_concepts.append(noise_pred_edit) + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond) + else: + noise_pred_uncond = noise_pred + noise_guidance = noise_pred + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps: + concept_weights = torch.zeros( + (enabled_editing_prompts, noise_guidance.shape[0]), + device=device, + dtype=noise_guidance.dtype, + ) + noise_guidance_edit = torch.zeros( + (enabled_editing_prompts, *noise_guidance.shape), + device=device, + dtype=noise_guidance.dtype, + ) + + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + if do_true_cfg: + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + else: # simple sega + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + + noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance_edit_tmp = noise_guidance_edit_tmp + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(device) + noise_guidance_edit = noise_guidance_edit.to(device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + + noise_guidance_edit = torch.einsum("cb,cbij->bij", concept_weights, noise_guidance_edit) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(device) + noise_guidance = noise_guidance + edit_guidance + + if do_true_cfg: + noise_pred = noise_guidance + noise_pred_uncond + else: + noise_pred = noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput( + image, + ) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index f97a4d0cd0f4..eed0575c322d 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \ ## Stable Diffusion XL We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md). + +## Dataset + +We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own. + +The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). + +We need to create a file `metadata.jsonl` in the directory with our images: + +``` +{"file_name": "01.jpg", "prompt": "prompt 01"} +{"file_name": "02.jpg", "prompt": "prompt 02"} +``` + +If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`. + +```sh +python convert_to_imagefolder.py --path my_dataset/ +``` + +We use `--dataset_name` and `--caption_column` with training scripts. + +``` +--dataset_name=my_dataset/ +--caption_column=prompt +``` diff --git a/examples/dreambooth/convert_to_imagefolder.py b/examples/dreambooth/convert_to_imagefolder.py new file mode 100644 index 000000000000..333080077428 --- /dev/null +++ b/examples/dreambooth/convert_to_imagefolder.py @@ -0,0 +1,32 @@ +import argparse +import json +import pathlib + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--path", + type=str, + required=True, + help="Path to folder with image-text pairs.", +) +parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.") +args = parser.parse_args() + +path = pathlib.Path(args.path) +if not path.exists(): + raise RuntimeError(f"`--path` '{args.path}' does not exist.") + +all_files = list(path.glob("*")) +captions = list(path.glob("*.txt")) +images = set(all_files) - set(captions) +images = {image.stem: image for image in images} +caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)} + +metadata = path.joinpath("metadata.jsonl") + +with metadata.open("w", encoding="utf-8") as f: + for caption, image in caption_image.items(): + caption_text = caption.read_text(encoding="utf-8") + json.dump({"file_name": image.name, args.caption_column: caption_text}, f) + f.write("\n") diff --git a/examples/research_projects/autoencoderkl/README.md b/examples/research_projects/autoencoderkl/README.md new file mode 100644 index 000000000000..c62018312da5 --- /dev/null +++ b/examples/research_projects/autoencoderkl/README.md @@ -0,0 +1,59 @@ +# AutoencoderKL training example + +## Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +## Training on CIFAR10 + +Please replace the validation image with your own image. + +```bash +accelerate launch train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ + --dataset_name=cifar10 \ + --image_column=img \ + --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \ + --num_train_epochs 100 \ + --gradient_accumulation_steps 2 \ + --learning_rate 4.5e-6 \ + --lr_scheduler cosine \ + --report_to wandb \ +``` + +## Training on ImageNet + +```bash +accelerate launch train_autoencoderkl.py \ + --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \ + --num_train_epochs 100 \ + --gradient_accumulation_steps 2 \ + --learning_rate 4.5e-6 \ + --lr_scheduler cosine \ + --report_to wandb \ + --mixed_precision bf16 \ + --train_data_dir /path/to/ImageNet/train \ + --validation_image ./image.png \ + --decoder_only +``` diff --git a/examples/research_projects/autoencoderkl/requirements.txt b/examples/research_projects/autoencoderkl/requirements.txt new file mode 100644 index 000000000000..fe501252b46a --- /dev/null +++ b/examples/research_projects/autoencoderkl/requirements.txt @@ -0,0 +1,15 @@ +accelerate>=0.16.0 +bitsandbytes +datasets +huggingface_hub +lpips +numpy +packaging +Pillow +taming_transformers +torch +torchvision +tqdm +transformers +wandb +xformers diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py new file mode 100644 index 000000000000..cf13ecdbf8ac --- /dev/null +++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py @@ -0,0 +1,1053 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import gc +import logging +import math +import os +import shutil +from pathlib import Path + +import accelerate +import lpips +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import AutoencoderKL +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.33.0.dev0") + +logger = get_logger(__name__) + + +@torch.no_grad() +def log_validation(vae, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + if not is_final_validation: + vae = accelerator.unwrap_model(vae) + else: + vae = AutoencoderKL.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + images = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + for i, validation_image in enumerate(args.validation_image): + validation_image = Image.open(validation_image).convert("RGB") + targets = image_transforms(validation_image).to(accelerator.device, weight_dtype) + targets = targets.unsqueeze(0) + + with inference_ctx: + reconstructions = vae(targets).sample + + images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step) + elif tracker.name == "wandb": + tracker.log( + { + f"{tracker_key}: Original (left), Reconstruction (right)": [ + wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + gc.collect() + torch.cuda.empty_cache() + + return images + + +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + if images is not None: + img_str = "You can find some example images below.\n\n" + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, "images.png")) + img_str += "![images](./images.png)\n" + + model_description = f""" +# autoencoderkl-{repo_id} + +These are autoencoderkl weights trained on {base_model} with new type of conditioning. +{img_str} +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion", + "stable-diffusion-diffusers", + "image-to-image", + "diffusers", + "autoencoderkl", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a AutoencoderKL training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the VAE model to train, leave as None to use standard VAE model configuration.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="autoencoderkl-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=4.5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--disc_learning_rate", + type=float, + default=4.5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--disc_lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help="A set of paths to the image be evaluated every `--validation_steps` and logged to `--report_to`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_autoencoderkl", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--rec_loss", + type=str, + default="l2", + help="The loss function for VAE reconstruction loss.", + ) + parser.add_argument( + "--kl_scale", + type=float, + default=1e-6, + help="Scaling factor for the Kullback-Leibler divergence penalty term.", + ) + parser.add_argument( + "--perceptual_scale", + type=float, + default=0.5, + help="Scaling factor for the LPIPS metric", + ) + parser.add_argument( + "--disc_start", + type=int, + default=50001, + help="Start for the discriminator", + ) + parser.add_argument( + "--disc_factor", + type=float, + default=1.0, + help="Scaling factor for the discriminator", + ) + parser.add_argument( + "--disc_scale", + type=float, + default=1.0, + help="Scaling factor for the discriminator", + ) + parser.add_argument( + "--disc_loss", + type=str, + default="hinge", + help="Loss function for the discriminator", + ) + parser.add_argument( + "--decoder_only", + action="store_true", + help="Only train the VAE decoder.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None: + raise ValueError("Cannot specify both `--pretrained_model_name_or_path` and `--model_config_name_or_path`") + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the diffusion model." + ) + + return args + + +def make_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + examples["pixel_values"] = images + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + return {"pixel_values": pixel_values} + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load AutoencoderKL + if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None: + config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse") + vae = AutoencoderKL.from_config(config) + elif args.pretrained_model_name_or_path is not None: + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, revision=args.revision) + else: + config = AutoencoderKL.load_config(args.model_config_name_or_path) + vae = AutoencoderKL.from_config(config) + if args.use_ema: + ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config) + perceptual_loss = lpips.LPIPS(net="vgg").eval() + discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init) + + # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + sub_dir = "autoencoderkl_ema" + ema_vae.save_pretrained(os.path.join(output_dir, sub_dir)) + + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + if isinstance(model, AutoencoderKL): + sub_dir = "autoencoderkl" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + else: + sub_dir = "discriminator" + os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True) + torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin")) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + if args.use_ema: + sub_dir = "autoencoderkl_ema" + load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL) + ema_vae.load_state_dict(load_model.state_dict()) + ema_vae.to(accelerator.device) + del load_model + + # pop models so that they are not loaded again + model = models.pop() + load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict( + os.path.join(input_dir, "discriminator", "pytorch_model.bin") + ) + model.load_state_dict(load_model.state_dict()) + del load_model + + model = models.pop() + load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl") + model.register_to_config(**load_model.config) + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(True) + if args.decoder_only: + vae.encoder.requires_grad_(False) + if getattr(vae, "quant_conv", None): + vae.quant_conv.requires_grad_(False) + vae.train() + discriminator.requires_grad_(True) + discriminator.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + vae.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + vae.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(vae).dtype != torch.float32: + raise ValueError(f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters()) + disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters()) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + disc_optimizer = optimizer_class( + disc_params_to_optimize, + lr=args.disc_learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataset = make_train_dataset(args, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + disc_lr_scheduler = get_scheduler( + args.disc_lr_scheduler, + optimizer=disc_optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + ( + vae, + discriminator, + optimizer, + disc_optimizer, + train_dataloader, + lr_scheduler, + disc_lr_scheduler, + ) = accelerator.prepare( + vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + perceptual_loss.to(accelerator.device, dtype=weight_dtype) + discriminator.to(accelerator.device, dtype=weight_dtype) + if args.use_ema: + ema_vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + vae.train() + discriminator.train() + for step, batch in enumerate(train_dataloader): + # Convert images to latent space and reconstruct from them + targets = batch["pixel_values"].to(dtype=weight_dtype) + posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist + latents = posterior.sample() + reconstructions = accelerator.unwrap_model(vae).decode(latents).sample + + if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start: + with accelerator.accumulate(vae): + # reconstruction loss. Pixel level differences between input vs output + if args.rec_loss == "l2": + rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none") + elif args.rec_loss == "l1": + rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none") + else: + raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}") + # perceptual loss. The high level feature mean squared error loss + with torch.no_grad(): + p_loss = perceptual_loss(reconstructions, targets) + + rec_loss = rec_loss + args.perceptual_scale * p_loss + nll_loss = rec_loss + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + logits_fake = discriminator(reconstructions) + g_loss = -torch.mean(logits_fake) + last_layer = accelerator.unwrap_model(vae).decoder.conv_out.weight + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + disc_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach() + disc_weight = disc_weight * args.disc_scale + disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 + + loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss + + logs = { + "loss": loss.detach().mean().item(), + "nll_loss": nll_loss.detach().mean().item(), + "rec_loss": rec_loss.detach().mean().item(), + "p_loss": p_loss.detach().mean().item(), + "kl_loss": kl_loss.detach().mean().item(), + "disc_weight": disc_weight.detach().mean().item(), + "disc_factor": disc_factor, + "g_loss": g_loss.detach().mean().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = vae.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + else: + with accelerator.accumulate(discriminator): + logits_real = discriminator(targets) + logits_fake = discriminator(reconstructions) + disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss + disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 + disc_loss = disc_factor * disc_loss(logits_real, logits_fake) + logs = { + "disc_loss": disc_loss.detach().mean().item(), + "logits_real": logits_real.detach().mean().item(), + "logits_fake": logits_fake.detach().mean().item(), + "disc_lr": disc_lr_scheduler.get_last_lr()[0], + } + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + if args.use_ema: + ema_vae.step(vae.parameters()) + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step == 1 or global_step % args.validation_steps == 0: + if args.use_ema: + ema_vae.store(vae.parameters()) + ema_vae.copy_to(vae.parameters()) + image_logs = log_validation( + vae, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + ema_vae.restore(vae.parameters()) + + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + vae = accelerator.unwrap_model(vae) + discriminator = accelerator.unwrap_model(discriminator) + if args.use_ema: + ema_vae.copy_to(vae.parameters()) + vae.save_pretrained(args.output_dir) + torch.save(discriminator.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")) + # Run a final round of validation. + image_logs = None + image_logs = log_validation( + vae=vae, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py index f825719a1364..8f2eb974398d 100644 --- a/examples/research_projects/pixart/controlnet_pixart_alpha.py +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -8,7 +8,6 @@ from diffusers.models.attention import BasicTransformerBlock from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.torch_utils import is_torch_version class PixArtControlNetAdapterBlock(nn.Module): @@ -151,10 +150,6 @@ def __init__( self.transformer = transformer self.controlnet = controlnet - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -220,18 +215,8 @@ def forward( print("Gradient checkpointing is not supported for the controlnet transformer model, yet.") exit(1) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -239,7 +224,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, None, - **ckpt_kwargs, ) else: # the control nets are only used for the blocks 1 to self.blocks_num diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b1801fbb2b4b..c36226225ad4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -28,6 +28,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], @@ -75,6 +76,13 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["hooks"].extend( + [ + "HookRegistry", + "PyramidAttentionBroadcastConfig", + "apply_pyramid_attention_broadcast", + ] + ) _import_structure["models"].extend( [ "AllegroTransformer3DModel", @@ -90,6 +98,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", + "CacheMixin", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "ConsisIDTransformer3DModel", @@ -588,6 +597,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, @@ -602,6 +612,7 @@ AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, + CacheMixin, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, ConsisIDTransformer3DModel, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 91b2760acad0..e745b1320e84 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,4 +2,6 @@ if is_torch_available(): + from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index bef4c65c41e1..3b2e4ed91c2f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -30,6 +30,9 @@ class ModelHook: _is_stateful = False + def __init__(self): + self.fn_ref: "HookFunctionReference" = None + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -48,8 +51,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: module (`torch.nn.Module`): The module attached to this hook. """ - module.forward = module._old_forward - del module._old_forward return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -99,6 +100,29 @@ def reset_state(self, module: torch.nn.Module): return module +class HookFunctionReference: + def __init__(self) -> None: + """A container class that maintains mutable references to forward pass functions in a hook chain. + + Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the + entire forward pass structure. + + Attributes: + pre_forward: A callable that processes inputs before the main forward pass. + post_forward: A callable that processes outputs after the main forward pass. + forward: The current forward function in the hook chain. + original_forward: The original forward function, stored when a hook provides a custom new_forward. + + The class enables hook removal by allowing updates to the forward chain through reference modification rather + than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to + be updated, preserving the execution order of the remaining hooks. + """ + self.pre_forward = None + self.post_forward = None + self.forward = None + self.original_forward = None + + class HookRegistry: def __init__(self, module_ref: torch.nn.Module) -> None: super().__init__() @@ -107,51 +131,71 @@ def __init__(self, module_ref: torch.nn.Module) -> None: self._module_ref = module_ref self._hook_order = [] + self._fn_refs = [] def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") - - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward + raise ValueError( + f"Hook with name {name} already exists in the registry. Please use a different name or " + f"first remove the existing hook and then add a new one." + ) self._module_ref = hook.initialize_hook(self._module_ref) - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward - + def create_new_forward(function_reference: HookFunctionReference): def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: + args, kwargs = function_reference.pre_forward(module, *args, **kwargs) + output = function_reference.forward(*args, **kwargs) + return function_reference.post_forward(module, output) - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) + return new_forward + + forward = self._module_ref.forward + fn_ref = HookFunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.forward = forward + + if hasattr(hook, "new_forward"): + fn_ref.original_forward = forward + fn_ref.forward = functools.update_wrapper( + functools.partial(hook.new_forward, self._module_ref), hook.new_forward + ) + + rewritten_forward = create_new_forward(fn_ref) self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward + functools.partial(rewritten_forward, self._module_ref), rewritten_forward ) + hook.fn_ref = fn_ref self.hooks[name] = hook self._hook_order.append(name) + self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] + return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: if name in self.hooks.keys(): + num_hooks = len(self._hook_order) hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.forward + if fn_ref.original_forward is not None: + old_forward = fn_ref.original_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].forward = old_forward + self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] - self._hook_order.remove(name) + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): @@ -161,7 +205,7 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: module._diffusers_hook.remove_hook(name, recurse=False) def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: + for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) @@ -180,9 +224,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry return module._diffusers_hook def __repr__(self) -> str: - hook_repr = "" + registry_repr = "" for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if self.hooks[hook_name].__class__.__repr__ is not object.__repr__: + hook_repr = self.hooks[hook_name].__repr__() + else: + hook_repr = self.hooks[hook_name].__class__.__name__ + registry_repr += f" ({i}) {hook_name} - {hook_repr}" if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" + registry_repr += "\n" + return f"HookRegistry(\n{registry_repr}\n)" diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py new file mode 100644 index 000000000000..9f8597d52f8c --- /dev/null +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -0,0 +1,314 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union + +import torch + +from ..models.attention_processor import Attention, MochiAttention +from ..utils import logging +from .hooks import HookRegistry, ModelHook + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +_ATTENTION_CLASSES = (Attention, MochiAttention) + +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") + + +@dataclass +class PyramidAttentionBroadcastConfig: + r""" + Configuration for Pyramid Attention Broadcast. + + Args: + spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of times a specific spatial attention broadcast is skipped before computing the attention states + to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., + old attention states will be re-used) before computing the new attention states again. + temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of times a specific temporal attention broadcast is skipped before computing the attention + states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times + (i.e., old attention states will be re-used) before computing the new attention states again. + cross_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of times a specific cross-attention broadcast is skipped before computing the attention states + to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., + old attention states will be re-used) before computing the new attention states again. + spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the spatial attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. + temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the temporal attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. + cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the cross-attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. + spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): + The identifiers to match against the layer names to determine if the layer is a spatial attention layer. + temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`): + The identifiers to match against the layer names to determine if the layer is a temporal attention layer. + cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): + The identifiers to match against the layer names to determine if the layer is a cross-attention layer. + """ + + spatial_attention_block_skip_range: Optional[int] = None + temporal_attention_block_skip_range: Optional[int] = None + cross_attention_block_skip_range: Optional[int] = None + + spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS + cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + + current_timestep_callback: Callable[[], int] = None + + # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase + # so not added for now) + + def __repr__(self) -> str: + return ( + f"PyramidAttentionBroadcastConfig(" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" + f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n" + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" + f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n" + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" + f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n" + f" current_timestep_callback={self.current_timestep_callback}\n" + ")" + ) + + +class PyramidAttentionBroadcastState: + r""" + State for Pyramid Attention Broadcast. + + Attributes: + iteration (`int`): + The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is + called before starting a new inference forward pass for PAB to work correctly. + cache (`Any`): + The cached output from the previous forward pass. This is used to re-use the attention states when the + attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module. + """ + + def __init__(self) -> None: + self.iteration = 0 + self.cache = None + + def reset(self): + self.iteration = 0 + self.cache = None + + def __repr__(self): + cache_repr = "" + if self.cache is None: + cache_repr = "None" + else: + cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})" + return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})" + + +class PyramidAttentionBroadcastHook(ModelHook): + r"""A hook that applies Pyramid Attention Broadcast to a given module.""" + + _is_stateful = True + + def __init__( + self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int] + ) -> None: + super().__init__() + + self.timestep_skip_range = timestep_skip_range + self.block_skip_range = block_skip_range + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.state = PyramidAttentionBroadcastState() + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + is_within_timestep_range = ( + self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] + ) + should_compute_attention = ( + self.state.cache is None + or self.state.iteration == 0 + or not is_within_timestep_range + or self.state.iteration % self.block_skip_range == 0 + ) + + if should_compute_attention: + output = self.fn_ref.original_forward(*args, **kwargs) + else: + output = self.state.cache + + self.state.cache = output + self.state.iteration += 1 + return output + + def reset_state(self, module: torch.nn.Module) -> None: + self.state.reset() + return module + + +def apply_pyramid_attention_broadcast( + module: torch.nn.Module, + config: PyramidAttentionBroadcastConfig, +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. + + PAB is an attention approximation method that leverages the similarity in attention states between timesteps to + reduce the computational cost of attention computation. The key takeaway from the paper is that the attention + similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and + spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently + than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process. + + Args: + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. + config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`): + The configuration to use for Pyramid Attention Broadcast. + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + >>> from diffusers.utils import export_to_video + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = PyramidAttentionBroadcastConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(100, 800), + ... current_timestep_callback=lambda: pipe.current_timestep, + ... ) + >>> apply_pyramid_attention_broadcast(pipe.transformer, config) + ``` + """ + if config.current_timestep_callback is None: + raise ValueError( + "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast." + ) + + if ( + config.spatial_attention_block_skip_range is None + and config.temporal_attention_block_skip_range is None + and config.cross_attention_block_skip_range is None + ): + logger.warning( + "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` " + "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. " + "To avoid this warning, please set one of the above parameters." + ) + config.spatial_attention_block_skip_range = 2 + + for name, submodule in module.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES): + # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB + # cannot be applied to this layer. For custom layers, users can extend this functionality and implement + # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. + continue + _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) + + +def _apply_pyramid_attention_broadcast_on_attention_class( + name: str, module: Attention, config: PyramidAttentionBroadcastConfig +) -> bool: + is_spatial_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) + and config.spatial_attention_block_skip_range is not None + and not getattr(module, "is_cross_attention", False) + ) + is_temporal_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) + and config.temporal_attention_block_skip_range is not None + and not getattr(module, "is_cross_attention", False) + ) + is_cross_attention = ( + any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers) + and config.cross_attention_block_skip_range is not None + and getattr(module, "is_cross_attention", False) + ) + + block_skip_range, timestep_skip_range, block_type = None, None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" + elif is_cross_attention: + block_skip_range = config.cross_attention_block_skip_range + timestep_skip_range = config.cross_attention_timestep_skip_range + block_type = "cross" + + if block_skip_range is None or timestep_skip_range is None: + logger.info( + f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration." + ) + return False + + logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") + _apply_pyramid_attention_broadcast_hook( + module, timestep_skip_range, block_skip_range, config.current_timestep_callback + ) + return True + + +def _apply_pyramid_attention_broadcast_hook( + module: Union[Attention, MochiAttention], + timestep_skip_range: Tuple[int, int], + block_skip_range: int, + current_timestep_callback: Callable[[], int], +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. + + Args: + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. + timestep_skip_range (`Tuple[int, int]`): + The range of timesteps to skip in the attention layer. The attention computations will be conditionally + skipped if the current timestep is within the specified range. + block_skip_range (`int`): + The number of times a specific attention broadcast is skipped before computing the attention states to + re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old + attention states will be re-used) before computing the new attention states again. + current_timestep_callback (`Callable[[], int]`): + A callback function that returns the current inference timestep. + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback) + registry.register_hook(hook, "pyramid_attention_broadcast") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e3f291ce2dc7..57a34609d28e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,6 +39,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] + _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ @@ -109,6 +110,7 @@ ConsistencyDecoderVAE, VQModel, ) + from .cache_utils import CacheMixin from .controlnets import ( ControlNetModel, ControlNetUnionModel, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 9036c027a535..357df0c31087 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -138,10 +138,6 @@ def __init__( self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, Decoder)): - module.gradient_checkpointing = value - def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index b62ed67ade29..f79aabe91dd3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample + residual if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # Down blocks for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + sample = self._gradient_checkpointing_func(down_block, sample) # Mid block - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = self._gradient_checkpointing_func(self.mid_block, sample) else: # Down blocks for down_block in self.down_blocks: @@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # Mid block - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = self._gradient_checkpointing_func(self.mid_block, sample) # Up blocks for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + sample = self._gradient_checkpointing_func(up_block, sample) else: # Mid block @@ -809,10 +795,6 @@ def __init__( sample_size - self.tile_overlap_w, ) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling(self) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 941b3eb07f10..829e0fe54dd2 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -421,15 +421,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, @@ -523,15 +516,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache[conv_cache_key] = resnet( @@ -637,15 +623,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, @@ -774,18 +753,11 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # 1. Down for i, down_block in enumerate(self.down_blocks): conv_cache_key = f"down_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, hidden_states, temb, None, @@ -793,8 +765,8 @@ def custom_forward(*inputs): ) # 2. Mid - hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, hidden_states, temb, None, @@ -940,16 +912,9 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # 1. Mid - hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, hidden_states, temb, sample, @@ -959,8 +924,8 @@ def custom_forward(*inputs): # 2. Up for i, up_block in enumerate(self.up_blocks): conv_cache_key = f"up_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, hidden_states, temb, sample, @@ -1122,10 +1087,6 @@ def __init__( self.tile_overlap_factor_height = 1 / 6 self.tile_overlap_factor_width = 1 / 5 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index e2236a7f20ad..22b833734f0f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -21,7 +21,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..attention_processor import Attention @@ -252,21 +252,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -278,9 +264,7 @@ def custom_forward(*inputs): hidden_states = attn(hidden_states, attention_mask=attention_mask) hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: hidden_states = self.resnets[0](hidden_states) @@ -350,22 +334,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for resnet in self.resnets: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: for resnet in self.resnets: hidden_states = resnet(hidden_states) @@ -426,22 +396,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for resnet in self.resnets: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: for resnet in self.resnets: @@ -545,26 +501,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) @@ -667,26 +607,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) else: hidden_states = self.mid_block(hidden_states) @@ -786,7 +710,7 @@ def __init__( self.use_tiling = False # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames - # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered. self.use_framewise_encoding = True self.use_framewise_decoding = True @@ -800,10 +724,6 @@ def __init__( self.tile_sample_stride_width = 192 self.tile_sample_stride_num_frames = 12 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -868,7 +788,7 @@ def disable_slicing(self) -> None: def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape - if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames: return self._temporal_tiled_encode(x) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 25753afd5ce6..75709ca10dfe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -338,16 +338,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -438,16 +429,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -573,16 +555,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -697,17 +670,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) @@ -838,19 +804,10 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) else: hidden_states = self.mid_block(hidden_states, temb) @@ -1017,10 +974,6 @@ def __init__( self.tile_sample_stride_width = 448 self.tile_sample_stride_num_frames = 8 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 920b0b62fef6..cd3eff73ed64 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -207,15 +207,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key), ) @@ -312,15 +305,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache[conv_cache_key] = resnet( @@ -393,15 +379,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key), ) @@ -531,21 +510,14 @@ def forward( hidden_states = hidden_states.permute(0, 4, 1, 2, 3) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache=conv_cache.get("block_in") ) for i, down_block in enumerate(self.down_blocks): conv_cache_key = f"down_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache["block_in"] = self.block_in( @@ -648,21 +620,14 @@ def forward( # 1. Mid if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache=conv_cache.get("block_in") ) for i, up_block in enumerate(self.up_blocks): conv_cache_key = f"up_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache["block_in"] = self.block_in( @@ -819,10 +784,6 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (MochiEncoder3D, MochiDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 38ad78c0707b..5a72cd395196 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -18,7 +18,6 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput @@ -97,47 +96,21 @@ def forward( upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func( + self.mid_block, + sample, + image_only_indicator, + ) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - image_only_indicator, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - image_only_indicator, - use_reentrant=False, - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func( + up_block, sample, image_only_indicator, ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - image_only_indicator, - ) else: # middle sample = self.mid_block(sample, image_only_indicator=image_only_indicator) @@ -229,10 +202,6 @@ def __init__( self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, TemporalDecoder)): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py index 35081c22dfc4..7ed727c55c37 100644 --- a/src/diffusers/models/autoencoders/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py @@ -154,10 +154,6 @@ def __init__( self.register_to_config(block_out_channels=decoder_block_out_channels) self.register_to_config(force_upcast=False) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (EncoderTiny, DecoderTiny)): - module.gradient_checkpointing = value - def scale_latents(self, x: torch.Tensor) -> torch.Tensor: """raw latents -> [0, 1]""" return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 7fc7d5a4d797..72e0acda3afe 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ...utils import BaseOutput, is_torch_version +from ...utils import BaseOutput from ...utils.torch_utils import randn_tensor from ..activations import get_activation from ..attention_processor import SpatialNorm @@ -156,28 +156,11 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.conv_in(sample) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # down - if is_torch_version(">=", "1.11.0"): - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, use_reentrant=False - ) - else: - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) else: # down @@ -305,41 +288,13 @@ def forward( upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - latent_embeds, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - latent_embeds, - use_reentrant=False, - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) else: # middle sample = self.mid_block(sample, latent_embeds) @@ -558,72 +513,28 @@ def forward( upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - latent_embeds, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), - masked_image, - mask, - use_reentrant=False, - ) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - latent_embeds, - use_reentrant=False, - ) - if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self._gradient_checkpointing_func( + self.condition_encoder, + masked_image, + mask, ) - sample = sample.to(upscale_dtype) - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), - masked_image, - mask, - ) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + # up + for up_block in self.up_blocks: if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) else: # middle sample = self.mid_block(sample, latent_embeds) @@ -890,17 +801,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: r"""The forward method of the `EncoderTiny` class.""" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + x = self._gradient_checkpointing_func(self.layers, x) else: # scale image from [-1, 1] to [0, 1] to match TAESD convention @@ -976,18 +877,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.tanh(x / 3) * 3 if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) - + x = self._gradient_checkpointing_func(self.layers, x) else: x = self.layers(x) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py new file mode 100644 index 000000000000..f2c621b3011a --- /dev/null +++ b/src/diffusers/models/cache_utils.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class CacheMixin: + r""" + A class for enable/disabling caching techniques on diffusion models. + + Supported caching techniques: + - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + """ + + _cache_config = None + + @property + def is_cache_enabled(self) -> bool: + return self._cache_config is not None + + def enable_cache(self, config) -> None: + r""" + Enable caching techniques on the model. + + Args: + config (`Union[PyramidAttentionBroadcastConfig]`): + The configuration for applying the caching technique. Currently supported caching techniques are: + - [`~hooks.PyramidAttentionBroadcastConfig`] + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = PyramidAttentionBroadcastConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(100, 800), + ... current_timestep_callback=lambda: pipe.current_timestep, + ... ) + >>> pipe.transformer.enable_cache(config) + ``` + """ + + from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + + if isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) + else: + raise ValueError(f"Cache config {type(config)} is not supported.") + + self._cache_config = config + + def disable_cache(self) -> None: + from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig + + if self._cache_config is None: + logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") + return + + if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry = HookRegistry.check_if_exists_or_initialize(self) + registry.remove_hook("pyramid_attention_broadcast", recurse=True) + else: + raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") + + self._cache_config = None + + def _reset_stateful_cache(self, recurse: bool = True) -> None: + from ..hooks import HookRegistry + + HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 1453aaf4362c..7a6ca886caed 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -31,8 +31,6 @@ from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block, @@ -659,10 +657,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 923b41119624..51c34b7fe965 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -22,7 +22,7 @@ from ...loaders import PeftAdapterMixin from ...models.attention_processor import AttentionProcessor from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -178,10 +178,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @classmethod def from_transformer( cls, @@ -330,24 +326,12 @@ def forward( block_samples = () for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: @@ -364,23 +348,11 @@ def custom_forward(*inputs): single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 9e361f2b16e5..1b0b4bae6410 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import JointTransformerBlock from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed @@ -262,10 +262,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer # we should have handled this in conversion script def _get_pos_embed_from_transformer(self, transformer): @@ -382,30 +378,16 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if self.context_embedder is not None: - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, - **ckpt_kwargs, ) else: # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), hidden_states, temb, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb) else: if self.context_embedder is not None: diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 807cbd339ef9..4edc91cacaa7 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -590,10 +590,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 1bf176101c61..076e966f3d37 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -29,8 +29,6 @@ from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) @@ -599,10 +597,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 8a8901d82d90..608be6b70277 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -20,7 +20,7 @@ from torch import Tensor, nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, is_torch_version, logging +from ...utils import BaseOutput, logging from ...utils.torch_utils import apply_freeu from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -864,10 +864,6 @@ def freeze_unet_params(self) -> None: for u in self.up_blocks: u.freeze_base_params() - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -1450,15 +1446,6 @@ def forward( base_blocks = list(zip(self.base_resnets, self.base_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base ): @@ -1468,13 +1455,7 @@ def custom_forward(*inputs): # apply base subblock if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - h_base = torch.utils.checkpoint.checkpoint( - create_custom_forward(b_res), - h_base, - temb, - **ckpt_kwargs, - ) + h_base = self._gradient_checkpointing_func(b_res, h_base, temb) else: h_base = b_res(h_base, temb) @@ -1491,13 +1472,7 @@ def custom_forward(*inputs): # apply ctrl subblock if apply_control: if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - h_ctrl = torch.utils.checkpoint.checkpoint( - create_custom_forward(c_res), - h_ctrl, - temb, - **ckpt_kwargs, - ) + h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb) else: h_ctrl = c_res(h_ctrl, temb) if c_attn is not None: @@ -1862,15 +1837,6 @@ def forward( and getattr(self, "b2", None) ) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): # FreeU: Only operate on the first two stages if is_freeu_enabled: @@ -1900,13 +1866,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4d5669e37f5a..3ef40ffb5783 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -21,12 +21,13 @@ import os import re from collections import OrderedDict -from functools import partial, wraps +from functools import wraps from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors import torch +import torch.utils.checkpoint from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards from huggingface_hub.utils import validate_hf_hub_args from torch import Tensor, nn @@ -168,6 +169,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): def __init__(self): super().__init__() + self._gradient_checkpointing_func = None + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite @@ -193,14 +196,35 @@ def is_gradient_checkpointing(self) -> bool: """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - def enable_gradient_checkpointing(self) -> None: + def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None: """ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or *checkpoint activations* in other frameworks). + + Args: + gradient_checkpointing_func (`Callable`, *optional*): + The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function + is used (`torch.utils.checkpoint.checkpoint`). """ if not self._supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute " + f"`_supports_gradient_checkpointing` to `True` in the class definition." + ) + + if gradient_checkpointing_func is None: + + def _gradient_checkpointing_func(module, *args): + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + return torch.utils.checkpoint.checkpoint( + module.__call__, + *args, + **ckpt_kwargs, + ) + + gradient_checkpointing_func = _gradient_checkpointing_func + + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) def disable_gradient_checkpointing(self) -> None: """ @@ -208,7 +232,7 @@ def disable_gradient_checkpointing(self) -> None: *checkpoint activations* in other frameworks). """ if self._supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self._set_gradient_checkpointing(enable=False) def set_use_npu_flash_attention(self, valid: bool) -> None: r""" @@ -1452,6 +1476,24 @@ def get_memory_footprint(self, return_buffers=True): mem = mem + mem_bufs return mem + def _set_gradient_checkpointing( + self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint + ) -> None: + is_gradient_checkpointing_set = False + + for name, module in self.named_modules(): + if hasattr(module, "gradient_checkpointing"): + logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'") + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to " + f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`." + ) + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index f1f36b87987d..4938ed23c506 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Dict, Union import torch import torch.nn as nn @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( Attention, @@ -444,10 +444,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -469,23 +465,11 @@ def forward( # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, - **ckpt_kwargs, ) else: @@ -500,22 +484,10 @@ def custom_forward(*inputs): for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - combined_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + combined_hidden_states = self._gradient_checkpointing_func( + block, combined_hidden_states, temb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c3039180b81d..53ec148209e0 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -20,10 +20,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -156,7 +157,7 @@ def forward( return hidden_states, encoder_hidden_states -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). @@ -330,9 +331,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -488,22 +486,13 @@ def forward( # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, image_rotary_emb, attention_kwargs, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 86a6628b5161..f312553e4c05 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 @@ -595,9 +595,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def _init_face_inputs(self): self.local_facial_extractor = LocalFacialExtractor( id_dim=self.LFE_id_dim, @@ -745,22 +742,13 @@ def forward( # 3. Transformer blocks ca_idx = 0 for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 7eac313c14db..6e83f49db71c 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ..attention import BasicTransformerBlock from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -144,10 +144,6 @@ def __init__( self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -186,19 +182,8 @@ def forward( # 2. Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, None, None, @@ -206,7 +191,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index be06f44a9efe..4fe1d99cb6ee 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Optional import torch @@ -19,13 +20,14 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle -class LatteTransformer3DModel(ModelMixin, ConfigMixin): +class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ @@ -164,9 +166,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -241,7 +240,7 @@ def forward( zip(self.transformer_blocks, self.temporal_transformer_blocks) ): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self._gradient_checkpointing_func( spatial_block, hidden_states, None, # attention_mask @@ -250,7 +249,6 @@ def forward( timestep_spatial, None, # cross_attention_kwargs None, # class_labels - use_reentrant=False, ) else: hidden_states = spatial_block( @@ -274,7 +272,7 @@ def forward( hidden_states = hidden_states + self.temp_pos_embed if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self._gradient_checkpointing_func( temp_block, hidden_states, None, # attention_mask @@ -283,7 +281,6 @@ def forward( timestep_temp, None, # cross_attention_kwargs None, # class_labels - use_reentrant=False, ) else: hidden_states = temp_block( diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index b1740cc08fdf..8e290074a018 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -17,7 +17,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ..attention import BasicTransformerBlock from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -184,10 +184,6 @@ def __init__( in_features=self.config.caption_channels, hidden_size=self.inner_dim ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -388,19 +384,8 @@ def forward( # 2. Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -408,7 +393,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, None, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index a2a54406430d..cface676b409 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,7 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, @@ -308,10 +308,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -438,21 +434,9 @@ def forward( # 2. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for block in self.transformer_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -460,7 +444,6 @@ def custom_forward(*inputs): timestep, post_patch_height, post_patch_width, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index bb370f20f21b..d81b6447adb0 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Dict, Optional, Union import numpy as np import torch @@ -29,7 +29,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_2d import Transformer2DModelOutput -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph @@ -346,10 +346,6 @@ def set_default_attn_processor(self): """ self.set_attn_processor(StableAudioAttnProcessor2_0()) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -416,25 +412,13 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, cross_attention_hidden_states, encoder_attention_mask, rotary_embedding, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 35e78877f27e..a88ee6c9c9b8 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import LegacyConfigMixin, register_to_config -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput @@ -321,10 +321,6 @@ def _init_patched_inputs(self, norm_type): in_features=self.caption_channels, hidden_size=self.inner_dim ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -417,19 +413,8 @@ def forward( # 2. Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -437,7 +422,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index f32c38394ba4..d5c93409c932 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -13,17 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -172,7 +173,7 @@ def forward( return hidden_states -class AllegroTransformer3DModel(ModelMixin, ConfigMixin): +class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ @@ -303,9 +304,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -375,23 +373,14 @@ def forward( for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep, attention_mask, encoder_attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 0376cc2fd70d..da7133791f37 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Dict, Union import torch import torch.nn as nn @@ -27,7 +27,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous -from ...utils import is_torch_version, logging +from ...utils import logging from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..normalization import CogView3PlusAdaLayerNormZeroTextImage @@ -289,10 +289,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -344,20 +340,11 @@ def forward( for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index db8d73856689..8a36f2254e44 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -32,9 +32,10 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph +from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -227,7 +228,7 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin ): """ The Transformer model introduced in Flux. @@ -422,10 +423,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -520,24 +517,12 @@ def forward( for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: @@ -564,23 +549,11 @@ def custom_forward(*inputs): for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 210a2e711972..c78d13344d81 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -22,9 +22,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, @@ -502,7 +503,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -671,10 +672,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -733,38 +730,24 @@ def forward( # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for block in self.transformer_blocks: - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb, - **ckpt_kwargs, ) for block in self.single_transformer_blocks: - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index b5498c0aed01..f5dc63f49562 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention @@ -361,10 +361,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -417,25 +413,13 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, encoder_attention_mask, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index d16430f27931..e6532f080d72 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -21,10 +21,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -305,7 +306,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). @@ -403,10 +404,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -459,22 +456,13 @@ def forward( for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, encoder_attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 2688d3640ea5..e24a28fc3d7b 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -28,7 +28,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -329,10 +329,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -404,24 +400,12 @@ def forward( is_skip = True if skip_layers is not None and index_block in skip_layers else False if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, joint_attention_kwargs, - **ckpt_kwargs, ) elif not is_skip: encoder_hidden_states, hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 3b5aedb79e3c..5580d0f70f9f 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -343,19 +343,11 @@ def forward( # 2. Blocks for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - block, - hidden_states, - None, - encoder_hidden_states, - None, - use_reentrant=False, + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, None, encoder_hidden_states, None ) else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) hidden_states_mix = hidden_states hidden_states_mix = hidden_states_mix + emb diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 84a1322d2a95..5a7fc32223d6 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -248,10 +248,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index b4e0cea7c71d..e082d524e766 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ...utils.torch_utils import apply_freeu from ..activations import get_activation from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 @@ -737,25 +737,9 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if attn is not None: hidden_states = attn(hidden_states, temb=temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: if attn is not None: hidden_states = attn(hidden_states, temb=temb) @@ -883,17 +867,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -902,12 +875,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, @@ -1156,23 +1124,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) else: @@ -1304,23 +1256,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1418,21 +1354,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1906,21 +1828,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2058,17 +1966,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2153,21 +2051,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2262,22 +2146,10 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states = self._gradient_checkpointing_func( + resnet, hidden_states, temb, - **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -2423,23 +2295,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn(hidden_states) else: hidden_states = resnet(hidden_states, temb) @@ -2588,23 +2444,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2721,21 +2561,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3251,21 +3077,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3409,17 +3221,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -3512,21 +3314,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3640,22 +3428,10 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states = self._gradient_checkpointing_func( + resnet, hidden_states, temb, - **ckpt_kwargs, ) hidden_states = attn( hidden_states, diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 3447fa0674bc..5674d8ba26ec 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -834,10 +834,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 195f7601dd54..8d7614a23383 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,7 +17,7 @@ import torch from torch import nn -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ...utils.torch_utils import apply_freeu from ..attention import Attention from ..resnet import ( @@ -1078,31 +1078,14 @@ def forward( ) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: hidden_states = attn( hidden_states, @@ -1110,11 +1093,7 @@ def custom_forward(*inputs): image_only_indicator=image_only_indicator, return_dict=False, )[0] - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) return hidden_states @@ -1169,34 +1148,9 @@ def forward( output_states = () for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) output_states = output_states + (hidden_states,) @@ -1281,25 +1235,8 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for resnet, attn in blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) hidden_states = attn( hidden_states, @@ -1308,11 +1245,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1385,34 +1318,9 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1495,25 +1403,8 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1521,11 +1412,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 398609778e65..845d93b9db09 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -37,11 +37,7 @@ from ..modeling_utils import ModelMixin from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) @@ -472,10 +468,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index d5d98c256357..f0eca75de169 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -35,11 +35,7 @@ from ..modeling_utils import ModelMixin from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) @@ -436,11 +432,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py index f611e7d82b1d..73bf0020b481 100644 --- a/src/diffusers/models/unets/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -205,10 +205,6 @@ def set_default_attn_processor(self): """ self.set_attn_processor(AttnProcessor()) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 1d0a38a8fb13..21e4db23a166 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ...utils import BaseOutput, deprecate, logging from ...utils.torch_utils import apply_freeu from ..attention import BasicTransformerBlock from ..attention_processor import ( @@ -324,25 +324,7 @@ def forward( blocks = zip(self.resnets, self.motion_modules) for resnet, motion_module in blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -514,23 +496,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) for i, (resnet, attn, motion_module) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -543,10 +509,7 @@ def custom_forward(*inputs): return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, num_frames=num_frames) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -733,23 +696,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -762,10 +709,7 @@ def custom_forward(*inputs): return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -896,24 +840,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -1080,34 +1007,12 @@ def forward( )[0] if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, + hidden_states = self._gradient_checkpointing_func( + motion_module, hidden_states, None, None, None, num_frames, None ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, None, None, None, num_frames, None) hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states @@ -1966,10 +1871,6 @@ def set_default_attn_processor(self) -> None: self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 172c1e6bbb05..db4ace9656a3 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -320,10 +320,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 238e6b411356..f57754435fdc 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -387,9 +387,6 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, value=False): - self.gradient_checkpointing = value - def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) @@ -456,29 +453,18 @@ def _down_encode(self, x, r_embed, clip): block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, SDCascadeResBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + x = self._gradient_checkpointing_func(block, x) elif isinstance(block, SDCascadeAttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, clip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, clip) elif isinstance(block, SDCascadeTimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, r_embed) else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False) + x = self._gradient_checkpointing_func(block) if i < len(repmap): x = repmap[i](x) level_outputs.insert(0, x) @@ -505,13 +491,6 @@ def _up_decode(self, level_outputs, r_embed, clip): block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - for i, (up_block, upscaler, repmap) in enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): @@ -523,19 +502,13 @@ def custom_forward(*inputs): x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) x = x.to(orig_type) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, skip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, skip) elif isinstance(block, SDCascadeAttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, clip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, clip) elif isinstance(block, SDCascadeTimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, r_embed) else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + x = self._gradient_checkpointing_func(block, x) if j < len(repmap): x = repmap[j](x) x = upscaler(x) diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py index 785f0f30aaae..94b39c84f055 100644 --- a/src/diffusers/models/unets/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -148,9 +148,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - pass - def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): encoder_hidden_states = self.encoder_proj(encoder_hidden_states) encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 91aedf2cdbe6..cb36a7a672de 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -683,6 +683,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -815,6 +819,7 @@ def __call__( negative_prompt_attention_mask, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -892,6 +897,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -933,6 +939,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index a33e26568772..00bed864ba34 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -38,7 +38,7 @@ from ...models.transformers.transformer_2d import Transformer2DModel from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from ...models.unets.unet_2d_condition import UNet2DConditionOutput -from ...utils import BaseOutput, is_torch_version, logging +from ...utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -673,11 +673,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, @@ -1114,23 +1109,7 @@ def forward( for i in range(num_layers): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1141,8 +1120,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1150,7 +1129,6 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] else: hidden_states = self.resnets[i](hidden_states, temb) @@ -1292,17 +1270,6 @@ def forward( for i in range(len(self.resnets[1:])): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1313,8 +1280,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1322,14 +1289,8 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i + 1]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb) else: for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: @@ -1466,23 +1427,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1493,8 +1438,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1502,7 +1447,6 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] else: hidden_states = self.resnets[i](hidden_states, temb) diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py index 0d78b987ce77..d2408417f590 100644 --- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py +++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -174,19 +174,16 @@ def forward( ) use_cache = False - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, + query_length, ) else: layer_outputs = layer_module( diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index d78d5508dc7f..99ae9025cd3e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -494,6 +494,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -627,6 +631,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -705,6 +710,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -763,6 +769,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 46e7b9ee468e..e37574ec9cb2 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -540,6 +540,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -680,6 +684,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -766,6 +771,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -818,6 +824,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 58793902345a..59d7c4cad547 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -591,6 +591,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -728,6 +732,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._attention_kwargs = attention_kwargs self._interrupt = False @@ -815,6 +820,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -877,6 +883,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 333e3418dca2..c4dc7e574f7e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -564,6 +564,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -700,6 +704,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -786,6 +791,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -844,6 +850,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 56f6c9149c6e..1ee63e5f7db6 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -766,26 +766,6 @@ def check_inputs( else: assert False - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - else: - assert False - if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] @@ -1370,6 +1350,8 @@ def __call__( if not isinstance(control_image, list): control_image = [control_image] + else: + control_image = control_image.copy() if not isinstance(control_mode, list): control_mode = [control_mode] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index a2e50d4f3e09..27e627e5bac9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -741,26 +741,6 @@ def check_inputs( else: assert False - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - else: - assert False - if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] @@ -1160,6 +1140,8 @@ def __call__( if not isinstance(control_image, list): control_image = [control_image] + else: + control_image = control_image.copy() if not isinstance(control_mode, list): control_mode = [control_mode] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index d4409c54b01c..8547675426e3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -746,26 +746,6 @@ def check_inputs( else: assert False - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - - else: - assert False - if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] @@ -1306,6 +1286,8 @@ def __call__( if not isinstance(control_image, list): control_image = [control_image] + else: + control_image = control_image.copy() if not isinstance(control_mode, list): control_mode = [control_mode] diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 4d9e50e3a2b4..bc276811ff4a 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -34,7 +34,7 @@ from ....models.transformers.dual_transformer_2d import DualTransformer2DModel from ....models.transformers.transformer_2d import Transformer2DModel from ....models.unets.unet_2d_condition import UNet2DConditionOutput -from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import apply_freeu @@ -963,10 +963,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. @@ -1597,21 +1593,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1734,23 +1716,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1876,21 +1842,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2035,23 +1987,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2230,25 +2166,9 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if attn is not None: hidden_states = attn(hidden_states, temb=temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: if attn is not None: hidden_states = attn(hidden_states, temb=temb) @@ -2377,17 +2297,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2396,12 +2305,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index f5716dc9c8ea..aa02dc1de5da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -28,8 +28,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel +from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, @@ -620,6 +619,10 @@ def joint_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -775,6 +778,7 @@ def __call__( self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -899,6 +903,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -957,9 +962,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 5c3d6ce611cc..8cc77ed4c148 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -456,6 +456,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -577,6 +581,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False device = self._execution_device @@ -644,6 +649,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -678,6 +684,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py index 5eb8d4c43d02..f07d064cbc22 100644 --- a/src/diffusers/pipelines/kolors/text_encoder.py +++ b/src/diffusers/pipelines/kolors/text_encoder.py @@ -605,7 +605,7 @@ def forward( layer = self._get_layer(index) if torch.is_grad_enabled() and self.gradient_checkpointing: - layer_ret = torch.utils.checkpoint.checkpoint( + layer_ret = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache ) else: @@ -666,10 +666,6 @@ def get_position_ids(self, input_ids, device): position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) return position_ids - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - def default_init(cls, *args, **kwargs): return cls(*args, **kwargs) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index d079e71fe38e..c7aa76a01fb8 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -544,10 +544,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LDMBertEncoder,)): - module.gradient_checkpointing = value - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -688,15 +684,8 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self._gradient_checkpointing_func( + encoder_layer, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index ce4ca313ebc4..578f373e8e3f 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -602,6 +602,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -633,7 +637,7 @@ def __call__( clean_caption: bool = True, mask_feature: bool = True, enable_temporal_attentions: bool = True, - decode_chunk_size: Optional[int] = None, + decode_chunk_size: int = 14, ) -> Union[LattePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -729,6 +733,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -790,6 +795,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -850,6 +856,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latents": deprecation_message = ( "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." @@ -858,7 +866,7 @@ def __call__( output_type = "latent" if not output_type == "latent": - video = self.decode_latents(latents, video_length, decode_chunk_size=14) + video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 5b37e9a503a8..133cb2c5f146 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -396,8 +396,10 @@ def check_inputs( prompt_attention_mask=None, negative_prompt_attention_mask=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) if prompt is not None and prompt_embeds is not None: raise ValueError( diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index a3028c50d8b7..d1f88b02c5cc 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,8 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import Mochi1LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLMochi -from ...models.transformers import MochiTransformer3DModel +from ...models import AutoencoderKLMochi, MochiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -467,6 +466,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -591,6 +594,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -660,6 +664,9 @@ def __call__( if self.interrupt: continue + # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need + # to make sure we're using the correct non-reversed timestep values. + self._current_timestep = 1000 - t latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) @@ -705,6 +712,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": video = latents else: diff --git a/src/diffusers/pipelines/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py index f4dbd4092e32..0e12340f6895 100644 --- a/src/diffusers/pipelines/onnx_utils.py +++ b/src/diffusers/pipelines/onnx_utils.py @@ -61,7 +61,7 @@ def __call__(self, **kwargs): return self.model.run(None, inputs) @staticmethod - def load_model(path: Union[str, Path], provider=None, sess_options=None): + def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None): """ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` @@ -75,7 +75,9 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None): logger.info("No onnxruntime provider specified, using CPUExecutionProvider") provider = "CPUExecutionProvider" - return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) + return ort.InferenceSession( + path, providers=[provider], sess_options=sess_options, provider_options=provider_options + ) def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): """ diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d56a2ce6eb30..0c1371c7556f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1133,11 +1133,20 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t def maybe_free_model_hooks(self): r""" - Function that offloads all components, removes all model hooks that were added when using - `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function - is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it - functions correctly when applying enable_model_cpu_offload. + Method that performs the following: + - Offloads all components. + - Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again. + In case the model has not been offloaded, this function is a no-op. + - Resets stateful diffusers hooks of denoiser components if they were added with + [`~hooks.HookRegistry.register_hook`]. + + Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions + correctly when applying `enable_model_cpu_offload`. """ + for component in self.components.values(): + if hasattr(component, "_reset_stateful_cache"): + component._reset_stateful_cache() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index f90fc82a98ad..9863c506d743 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -29,7 +29,6 @@ AttnProcessor, ) from ...models.modeling_utils import ModelMixin -from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm @@ -138,9 +137,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 @@ -159,33 +155,13 @@ def forward(self, x, r, c): r_embed = self.gen_r_embedding(r) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - for block in self.blocks: - if isinstance(block, AttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, c_embed, use_reentrant=False - ) - elif isinstance(block, TimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False - ) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) - else: - for block in self.blocks: - if isinstance(block, AttnBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed) - elif isinstance(block, TimestepBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = self._gradient_checkpointing_func(block, x, c_embed) + elif isinstance(block, TimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block, x) else: for block in self.blocks: if isinstance(block, AttnBlock): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index eb40d79b9f60..624d5a5cd4f3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -142,7 +142,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`. trained_betas (`np.ndarray`, *optional*): An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`. variance_type (`str`, defaults to `"fixed_small"`): diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index ae953cfb966b..a14797b42f7a 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -319,7 +319,11 @@ def step( prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf - prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise + # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5 refers to formula (8a) of the same paper, + # which tells to sample from a Gaussian distribution with mean "(alpha_prod_t_prev**0.5) * original_image" + # and variance "(1 - alpha_prod_t_prev)". This means that the standard Gaussian distribution "noise" should be + # scaled by the square root of the variance (as it is done here), however Algorithm 1 Line 5 tells to scale by the variance. + prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 183d6beb35c3..6a1978944c9f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,40 @@ from ..utils import DummyObject, requires_backends +class HookRegistry(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PyramidAttentionBroadcastConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +def apply_pyramid_attention_broadcast(*args, **kwargs): + requires_backends(apply_pyramid_attention_broadcast, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -197,6 +231,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CacheMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 12eef8899bbb..3c8911773e39 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -149,3 +149,13 @@ def apply_freeu( res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) return hidden_states, res_hidden_states + + +def get_torch_cuda_device_capability(): + if torch.cuda.is_available(): + device = torch.device("cuda") + compute_capability = torch.cuda.get_device_capability(device) + compute_capability = f"{compute_capability[0]}.{compute_capability[1]}" + return float(compute_capability) + else: + return None diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py new file mode 100644 index 000000000000..74bd43c52315 --- /dev/null +++ b/tests/hooks/test_hooks.py @@ -0,0 +1,382 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers.hooks import HookRegistry, ModelHook +from diffusers.training_utils import free_memory +from diffusers.utils.logging import get_logger +from diffusers.utils.testing_utils import CaptureLogger, torch_device + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class DummyBlock(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.proj_in = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.proj_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.activation(x) + x = self.proj_out(x) + return x + + +class DummyModel(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +class AddHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + logger.debug("AddHook pre_forward") + args = ((x + self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("AddHook post_forward") + return output + + +class MultiplyHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module, *args, **kwargs): + logger.debug("MultiplyHook pre_forward") + args = ((x * self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("MultiplyHook post_forward") + return output + + def __repr__(self): + return f"MultiplyHook(value={self.value})" + + +class StatefulAddHook(ModelHook): + _is_stateful = True + + def __init__(self, value: int): + super().__init__() + self.value = value + self.increment = 0 + + def pre_forward(self, module, *args, **kwargs): + logger.debug("StatefulAddHook pre_forward") + add_value = self.value + self.increment + self.increment += 1 + args = ((x + add_value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def reset_state(self, module): + self.increment = 0 + + +class SkipLayerHook(ModelHook): + def __init__(self, skip_layer: bool): + super().__init__() + self.skip_layer = skip_layer + + def pre_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook pre_forward") + return args, kwargs + + def new_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook new_forward") + if self.skip_layer: + return args[0] + return self.fn_ref.original_forward(*args, **kwargs) + + def post_forward(self, module, output): + logger.debug("SkipLayerHook post_forward") + return output + + +class HookTests(unittest.TestCase): + in_features = 4 + hidden_features = 8 + out_features = 4 + num_layers = 2 + + def setUp(self): + params = self.get_module_parameters() + self.model = DummyModel(**params) + self.model.to(torch_device) + + def tearDown(self): + super().tearDown() + + del self.model + gc.collect() + free_memory() + + def get_module_parameters(self): + return { + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "num_layers": self.num_layers, + } + + def get_generator(self): + return torch.manual_seed(0) + + def test_hook_registry(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + registry_repr = repr(registry) + expected_repr = ( + "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")" + ) + + self.assertEqual(len(registry.hooks), 2) + self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) + self.assertEqual(registry_repr, expected_repr) + + registry.remove_hook("add_hook") + + self.assertEqual(len(registry.hooks), 1) + self.assertEqual(registry._hook_order, ["multiply_hook"]) + + def test_stateful_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "stateful_add_hook") + + self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0) + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + num_repeats = 3 + + for i in range(num_repeats): + result = self.model(input) + if i == 0: + output1 = result + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats) + + registry.reset_stateful_hooks() + output2 = self.model(input) + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1) + self.assertTrue(torch.allclose(output1, output2)) + + def test_inference(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + output1 = self.model(input).mean().detach().cpu().item() + + registry.remove_hook("multiply_hook") + new_input = input * 2 + output2 = self.model(new_input).mean().detach().cpu().item() + + registry.remove_hook("add_hook") + new_input = input * 2 + 1 + output3 = self.model(new_input).mean().detach().cpu().item() + + self.assertAlmostEqual(output1, output2, places=5) + self.assertAlmostEqual(output1, output3, places=5) + + def test_skip_layer_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + + input = torch.zeros(1, 4, device=torch_device) + output = self.model(input).mean().detach().cpu().item() + self.assertEqual(output, 0.0) + + registry.remove_hook("skip_layer_hook") + registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_skip_layer_internal_block(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1) + input = torch.zeros(1, 4, device=torch_device) + + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + with self.assertRaises(RuntimeError) as cm: + self.model(input).mean().detach().cpu().item() + self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception)) + + registry.remove_hook("skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1]) + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_invocation_order_stateful_first(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "add_hook") + registry.register_hook(AddHook(2), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_middle(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(2), "add_hook") + registry.register_hook(StatefulAddHook(1), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook_2") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_last(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + registry.register_hook(StatefulAddHook(3), "add_hook_2") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "StatefulAddHook pre_forward\n" + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 05050e05bb19..c3cb082b0ef1 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -68,6 +68,7 @@ torch_all_close, torch_device, ) +from diffusers.utils.torch_utils import get_torch_cuda_device_capability from ..others.test_utils import TOKEN, USER, is_staging_test @@ -953,24 +954,15 @@ def test_gradient_checkpointing_is_applied( init_dict["block_out_channels"] = block_out_channels model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - model = model_class_copy(**init_dict) model.enable_gradient_checkpointing() + modules_with_gc_enabled = {} + for submodule in model.modules(): + if hasattr(submodule, "gradient_checkpointing"): + self.assertTrue(submodule.gradient_checkpointing) + modules_with_gc_enabled[submodule.__class__.__name__] = True + assert set(modules_with_gc_enabled.keys()) == expected_set assert all(modules_with_gc_enabled.values()), "All modules should be enabled" @@ -1393,6 +1385,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype): @require_torch_gpu def test_layerwise_casting_memory(self): MB_TOLERANCE = 0.2 + LEAST_COMPUTE_CAPABILITY = 8.0 def reset_memory_stats(): gc.collect() @@ -1421,10 +1414,12 @@ def get_memory_usage(storage_dtype, compute_dtype): torch.float8_e4m3fn, torch.bfloat16 ) + compute_capability = get_torch_cuda_device_capability() self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint) - # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. - # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. - self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) + # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. + # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. + if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: + self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few # bytes. This only happens for some models, so we allow a small tolerance. # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 322be373641a..2a4d0a36dffa 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -34,13 +34,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = AllegroPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -59,14 +59,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = AllegroTransformer3DModel( num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4, - num_layers=1, + num_layers=num_layers, cross_attention_dim=24, sample_width=8, sample_height=8, diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 9ce3d8e9de31..750f20f8fbe5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, to_np, @@ -41,7 +42,7 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -60,7 +61,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -72,7 +73,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index a3bc1658de74..bab343a5954c 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -19,12 +19,15 @@ from ..test_pipelines_common import ( FluxIPAdapterTesterMixin, PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): +class FluxPipelineFastTests( + unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin +): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -33,13 +36,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, in_channels=4, - num_layers=1, - num_single_layers=1, + num_layers=num_layers, + num_single_layers=num_single_layers, attention_head_dim=16, num_attention_heads=2, joint_attention_dim=32, diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ce03381f90d2..1ecfee666fcd 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -30,13 +30,13 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -55,15 +55,15 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=10, - num_layers=1, - num_single_layers=1, + num_layers=num_layers, + num_single_layers=num_single_layers, num_refiner_layers=1, patch_size=1, patch_size_t=1, diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 2d5bcba8237a..64459a659179 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -27,6 +27,7 @@ DDIMScheduler, LattePipeline, LatteTransformer3DModel, + PyramidAttentionBroadcastConfig, ) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -38,13 +39,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -54,11 +55,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True - def get_dummy_components(self): + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=2, + cross_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 700), + temporal_attention_timestep_skip_range=(100, 800), + cross_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + temporal_attention_block_identifiers=["temporal_transformer_blocks"], + cross_attention_block_identifiers=["transformer_blocks"], + ) + + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=1, + num_layers=num_layers, patch_size=2, attention_head_dim=8, num_attention_heads=3, diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 139778994b87..de5faa185c2f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -24,10 +24,12 @@ DDIMScheduler, DiffusionPipeline, KolorsPipeline, + PyramidAttentionBroadcastConfig, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor @@ -2322,6 +2324,141 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertLess(max_diff, expected_max_difference) +class PyramidAttentionBroadcastTesterMixin: + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + ) + + def test_pyramid_attention_broadcast_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + num_layers = 0 + num_single_layers = 0 + dummy_component_kwargs = {} + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters + if "num_layers" in dummy_component_parameters: + num_layers = 2 + dummy_component_kwargs["num_layers"] = num_layers + if "num_single_layers" in dummy_component_parameters: + num_single_layers = 2 + dummy_component_kwargs["num_single_layers"] = num_single_layers + + components = self.get_dummy_components(**dummy_component_kwargs) + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + self.pab_config.current_timestep_callback = lambda: pipe.current_timestep + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + denoiser.enable_cache(self.pab_config) + + expected_hooks = 0 + if self.pab_config.spatial_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.pab_config.temporal_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.pab_config.cross_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + count = 0 + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + count += 1 + self.assertTrue( + isinstance(hook, PyramidAttentionBroadcastHook), + "Hook should be of type PyramidAttentionBroadcastHook.", + ) + self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.") + self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.") + + # Perform dummy inference step to ensure state is updated + def pab_state_check_callback(pipe, i, t, kwargs): + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + self.assertTrue( + hook.state.cache is not None, + "Cache should have updated during inference.", + ) + self.assertTrue( + hook.state.iteration == i + 1, + "Hook iteration state should have updated during inference.", + ) + return {} + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 2 + inputs["callback_on_step_end"] = pab_state_check_callback + pipe(**inputs)[0] + + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + self.assertTrue( + hook.state.cache is None, + "Cache should be reset to None after inference.", + ) + self.assertTrue( + hook.state.iteration == 0, + "Iteration should be reset to 0 after inference.", + ) + + def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2): + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # Run inference without PAB + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + original_image_slice = output.flatten() + original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) + + # Run inference with PAB enabled + self.pab_config.current_timestep_callback = lambda: pipe.current_timestep + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + denoiser.enable_cache(self.pab_config) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + image_slice_pab_enabled = output.flatten() + image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:])) + + # Run inference with PAB disabled + denoiser.disable_cache() + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + image_slice_pab_disabled = output.flatten() + image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:])) + + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=expected_atol + ), "PAB outputs should not differ much in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image.