Skip to content

Commit 44bfd4c

Browse files
add CacheMixin
1 parent 5c25cd2 commit 44bfd4c

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step,
127127
num_inference_steps=50,
128128
guidance_scale=args.guidance_scale,
129129
generator=generator,
130-
max_sequence_length=512,
131130
height=args.resolution,
132131
width=args.resolution,
133132
).images[0]
@@ -1075,7 +1074,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10751074
mu = torch.sqrt(image_seq_lens / 256)
10761075
mu = mu * 0.75 + 0.25
10771076
scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(dtype=pixel_latents.dtype, device=pixel_latents.device)
1078-
scale_factors = scale_factors.view(4, 1, 1, 1)
1077+
scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1)
10791078
noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise
10801079
concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
10811080
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
@@ -1114,7 +1113,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11141113
# flow-matching loss
11151114
target = noise - pixel_latents
11161115

1117-
weighting = weighting.unsqueeze(1).unsqueeze(2).unsqueeze(3) # [4, 1, 1, 1]
1116+
weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
11181117
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),1)
11191118
loss = loss.mean()
11201119
accelerator.backward(loss)

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
import torch
1818
import torch.nn as nn
1919
import torch.nn.functional as F
20-
20+
from ...loaders import PeftAdapterMixin
2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...models.attention import FeedForward
2323
from ...models.attention_processor import Attention
2424
from ...models.modeling_utils import ModelMixin
2525
from ...models.normalization import AdaLayerNormContinuous
2626
from ...utils import logging
27+
from ..cache_utils import CacheMixin
2728
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
2829
from ..modeling_outputs import Transformer2DModelOutput
2930

@@ -285,6 +286,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
285286

286287

287288
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
289+
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
288290
r"""
289291
Args:
290292
patch_size (`int`, defaults to `2`):
@@ -390,7 +392,6 @@ def forward(
390392
p = self.config.patch_size
391393
post_patch_height = height // p
392394
post_patch_width = width // p
393-
394395
hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
395396

396397
temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

0 commit comments

Comments
 (0)