Skip to content

Commit de84a04

Browse files
authored
Merge branch 'main' into cogvideox1.1-5b
2 parents 67cb373 + 5b972fb commit de84a04

File tree

63 files changed

+3226
-2858
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+3226
-2858
lines changed

docs/source/en/api/models/controlnet.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
3939

4040
## ControlNetOutput
4141

42-
[[autodoc]] models.controlnet.ControlNetOutput
42+
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
4343

4444
## FlaxControlNetModel
4545

4646
[[autodoc]] FlaxControlNetModel
4747

4848
## FlaxControlNetOutput
4949

50-
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
50+
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput

docs/source/en/api/models/controlnet_sd3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
3838

3939
## SD3ControlNetOutput
4040

41-
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
41+
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
4242

examples/community/matryoshka.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ def forward(
868868
blocks = list(zip(self.resnets, self.attentions))
869869

870870
for i, (resnet, attn) in enumerate(blocks):
871-
if self.training and self.gradient_checkpointing:
871+
if torch.is_grad_enabled() and self.gradient_checkpointing:
872872

873873
def create_custom_forward(module, return_dict=None):
874874
def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ def forward(
10291029

10301030
hidden_states = self.resnets[0](hidden_states, temb)
10311031
for attn, resnet in zip(self.attentions, self.resnets[1:]):
1032-
if self.training and self.gradient_checkpointing:
1032+
if torch.is_grad_enabled() and self.gradient_checkpointing:
10331033

10341034
def create_custom_forward(module, return_dict=None):
10351035
def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ def forward(
11911191

11921192
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
11931193

1194-
if self.training and self.gradient_checkpointing:
1194+
if torch.is_grad_enabled() and self.gradient_checkpointing:
11951195

11961196
def create_custom_forward(module, return_dict=None):
11971197
def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ def forward(
13641364

13651365
# Blocks
13661366
for block in self.transformer_blocks:
1367-
if self.training and self.gradient_checkpointing:
1367+
if torch.is_grad_enabled() and self.gradient_checkpointing:
13681368

13691369
def create_custom_forward(module, return_dict=None):
13701370
def custom_forward(*inputs):

examples/research_projects/pixart/controlnet_pixart_alpha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def forward(
215215

216216
# 2. Blocks
217217
for block_index, block in enumerate(self.transformer.transformer_blocks):
218-
if self.training and self.gradient_checkpointing:
218+
if torch.is_grad_enabled() and self.gradient_checkpointing:
219219
# rc todo: for training and gradient checkpointing
220220
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
221221
exit(1)

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,11 @@ def forward(
229229
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
230230
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
231231
return_dict (`bool`, defaults to `True`):
232-
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
232+
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
233233
234234
Returns:
235-
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
236-
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
235+
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
236+
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
237237
returned where the first element is the sample tensor.
238238
"""
239239
# check channel order

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@
487487

488488

489489
else:
490-
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
490+
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
491491
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
492492
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
493493
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
@@ -914,7 +914,7 @@
914914
except OptionalDependencyNotAvailable:
915915
from .utils.dummy_flax_objects import * # noqa F403
916916
else:
917-
from .models.controlnet_flax import FlaxControlNetModel
917+
from .models.controlnets.controlnet_flax import FlaxControlNetModel
918918
from .models.modeling_flax_utils import FlaxModelMixin
919919
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
920920
from .models.vae_flax import FlaxAutoencoderKL

src/diffusers/models/__init__.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@
3636
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
3737
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
3838
_import_structure["autoencoders.vq_model"] = ["VQModel"]
39-
_import_structure["controlnet"] = ["ControlNetModel"]
40-
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
41-
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
42-
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
43-
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
44-
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
39+
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
40+
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
41+
_import_structure["controlnets.controlnet_hunyuan"] = [
42+
"HunyuanDiT2DControlNetModel",
43+
"HunyuanDiT2DMultiControlNetModel",
44+
]
45+
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
46+
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
47+
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
48+
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
4549
_import_structure["embeddings"] = ["ImageProjection"]
4650
_import_structure["modeling_utils"] = ["ModelMixin"]
4751
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -74,7 +78,7 @@
7478
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
7579

7680
if is_flax_available():
77-
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
81+
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
7882
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
7983
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
8084

@@ -94,12 +98,19 @@
9498
ConsistencyDecoderVAE,
9599
VQModel,
96100
)
97-
from .controlnet import ControlNetModel
98-
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
99-
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
100-
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
101-
from .controlnet_sparsectrl import SparseControlNetModel
102-
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
101+
from .controlnets import (
102+
ControlNetModel,
103+
ControlNetXSAdapter,
104+
FluxControlNetModel,
105+
FluxMultiControlNetModel,
106+
HunyuanDiT2DControlNetModel,
107+
HunyuanDiT2DMultiControlNetModel,
108+
MultiControlNetModel,
109+
SD3ControlNetModel,
110+
SD3MultiControlNetModel,
111+
SparseControlNetModel,
112+
UNetControlNetXSModel,
113+
)
103114
from .embeddings import ImageProjection
104115
from .modeling_utils import ModelMixin
105116
from .transformers import (
@@ -137,7 +148,7 @@
137148
)
138149

139150
if is_flax_available():
140-
from .controlnet_flax import FlaxControlNetModel
151+
from .controlnets import FlaxControlNetModel
141152
from .unets import FlaxUNet2DConditionModel
142153
from .vae_flax import FlaxAutoencoderKL
143154

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
506506
sample = self.temp_conv_in(sample)
507507
sample = sample + residual
508508

509-
if self.gradient_checkpointing:
509+
if torch.is_grad_enabled() and self.gradient_checkpointing:
510510

511511
def create_custom_forward(module):
512512
def custom_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
646646

647647
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
648648

649-
if self.gradient_checkpointing:
649+
if torch.is_grad_enabled() and self.gradient_checkpointing:
650650

651651
def create_custom_forward(module):
652652
def custom_forward(*inputs):

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def forward(
420420
for i, resnet in enumerate(self.resnets):
421421
conv_cache_key = f"resnet_{i}"
422422

423-
if self.training and self.gradient_checkpointing:
423+
if torch.is_grad_enabled() and self.gradient_checkpointing:
424424

425425
def create_custom_forward(module):
426426
def create_forward(*inputs):
@@ -522,7 +522,7 @@ def forward(
522522
for i, resnet in enumerate(self.resnets):
523523
conv_cache_key = f"resnet_{i}"
524524

525-
if self.training and self.gradient_checkpointing:
525+
if torch.is_grad_enabled() and self.gradient_checkpointing:
526526

527527
def create_custom_forward(module):
528528
def create_forward(*inputs):
@@ -636,7 +636,7 @@ def forward(
636636
for i, resnet in enumerate(self.resnets):
637637
conv_cache_key = f"resnet_{i}"
638638

639-
if self.training and self.gradient_checkpointing:
639+
if torch.is_grad_enabled() and self.gradient_checkpointing:
640640

641641
def create_custom_forward(module):
642642
def create_forward(*inputs):
@@ -773,7 +773,7 @@ def forward(
773773

774774
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
775775

776-
if self.training and self.gradient_checkpointing:
776+
if torch.is_grad_enabled() and self.gradient_checkpointing:
777777

778778
def create_custom_forward(module):
779779
def custom_forward(*inputs):
@@ -939,7 +939,7 @@ def forward(
939939

940940
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
941941

942-
if self.training and self.gradient_checkpointing:
942+
if torch.is_grad_enabled() and self.gradient_checkpointing:
943943

944944
def create_custom_forward(module):
945945
def custom_forward(*inputs):

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def forward(
206206
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
207207
conv_cache_key = f"resnet_{i}"
208208

209-
if self.training and self.gradient_checkpointing:
209+
if torch.is_grad_enabled() and self.gradient_checkpointing:
210210

211211
def create_custom_forward(module):
212212
def create_forward(*inputs):
@@ -311,7 +311,7 @@ def forward(
311311
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
312312
conv_cache_key = f"resnet_{i}"
313313

314-
if self.training and self.gradient_checkpointing:
314+
if torch.is_grad_enabled() and self.gradient_checkpointing:
315315

316316
def create_custom_forward(module):
317317
def create_forward(*inputs):
@@ -392,7 +392,7 @@ def forward(
392392
for i, resnet in enumerate(self.resnets):
393393
conv_cache_key = f"resnet_{i}"
394394

395-
if self.training and self.gradient_checkpointing:
395+
if torch.is_grad_enabled() and self.gradient_checkpointing:
396396

397397
def create_custom_forward(module):
398398
def create_forward(*inputs):
@@ -529,7 +529,7 @@ def forward(
529529
hidden_states = self.proj_in(hidden_states)
530530
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
531531

532-
if self.training and self.gradient_checkpointing:
532+
if torch.is_grad_enabled() and self.gradient_checkpointing:
533533

534534
def create_custom_forward(module):
535535
def create_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(
646646
hidden_states = self.conv_in(hidden_states)
647647

648648
# 1. Mid
649-
if self.training and self.gradient_checkpointing:
649+
if torch.is_grad_enabled() and self.gradient_checkpointing:
650650

651651
def create_custom_forward(module):
652652
def create_forward(*inputs):

0 commit comments

Comments
 (0)