
Commit 692e5cc

update
1 parent b6e10e7 commit 692e5cc

File tree

3 files changed (+25, -39 lines)

examples/cogview4-control/train_control_cogview4.py

Lines changed: 8 additions & 12 deletions
@@ -132,7 +132,7 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step,
             control_image=validation_image,
             num_inference_steps=50,
             guidance_scale=args.guidance_scale,
-            max_sequence_length=args.max_sequence_length, # For downstream task training usage, training can be performed on a batch basis.
+            max_sequence_length=args.max_sequence_length,  # For downstream task training usage, training can be performed on a batch basis.
             padding_type="max_length",
             generator=generator,
             height=args.resolution,
@@ -660,7 +660,7 @@ def prepare_train_dataset(dataset, accelerator):
         [
             transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
             transforms.ToTensor(),
-            transforms.Lambda(lambda x: x * 2 - 1)
+            transforms.Lambda(lambda x: x * 2 - 1),
         ]
     )
 
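For context on the transform chain this hunk touches (the added trailing comma is style only): ToTensor scales pixels to [0, 1] and the Lambda remaps them to [-1, 1], the input range diffusion VAEs are typically trained on. A minimal standalone sketch of the same normalization:

import numpy as np
from torchvision import transforms

normalize = transforms.Compose(
    [
        transforms.ToTensor(),                   # uint8 HWC in [0, 255] -> float CHW in [0, 1]
        transforms.Lambda(lambda x: x * 2 - 1),  # [0, 1] -> [-1, 1]
    ]
)

img = np.full((64, 64, 3), 255, dtype=np.uint8)  # an all-white image
out = normalize(img)
print(out.min().item(), out.max().item())  # 1.0 1.0 -- white maps to +1.0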
@@ -1074,7 +1074,6 @@ def load_model_hook(models, input_dir):
                 )
 
                 # Add noise according for cogview4
-                # FIXME: The issue of variable-length training has not been resolved, here it is still extended to the longest one.
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
                 sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)
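The lookup this hunk keeps (minus the FIXME) maps a per-sample value u in [0, 1) to a discrete scheduler index. A self-contained sketch with stand-in tables for the scheduler's timesteps and sigmas (the real values come from the flow-matching scheduler's config, and the final noising line shows the typical rectified-flow interpolation, not code from this file):

import torch

num_train_timesteps = 1000
timesteps_table = torch.linspace(1000, 1, num_train_timesteps)  # stand-in for noise_scheduler_copy.timesteps
sigmas_table = torch.linspace(1.0, 1e-3, num_train_timesteps)   # stand-in for noise_scheduler_copy.sigmas

u = torch.rand(4)                           # one draw per sample in the batch
indices = (u * num_train_timesteps).long()  # map [0, 1) -> {0, ..., 999}
timesteps = timesteps_table[indices]
sigmas = sigmas_table[indices].view(-1, 1, 1, 1)  # broadcastable over latent dims
# Typical rectified-flow noising with these sigmas:
# noisy_latents = (1 - sigmas) * latents + sigmas * noise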
@@ -1095,12 +1094,10 @@ def load_model_hook(models, input_dir):
                 text_encoding_pipeline = text_encoding_pipeline.to("cuda")
 
                 with torch.no_grad():
-                    # Since the batch will be padded, max_length should be used for padding.
-                    prompt_embeds,pooled_prompt_embeds,= text_encoding_pipeline.encode_prompt(
-                        captions, "",
-                        max_sequence_length=args.max_sequence_length,
-                        padding_type="max_length"
-                    )
+                    (
+                        prompt_embeds,
+                        pooled_prompt_embeds,
+                    ) = text_encoding_pipeline.encode_prompt(captions, "")
                 original_size = (args.resolution, args.resolution)
                 original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
 
@@ -1109,8 +1106,6 @@ def load_model_hook(models, input_dir):
 
                 target_size = target_size.repeat(len(batch["captions"]), 1)
                 original_size = original_size.repeat(len(batch["captions"]), 1)
-
-                # TODO: Should a parameter be set here for passing? This is not present in Flux.
                 crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
                 crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
 
@@ -1140,7 +1135,8 @@ def load_model_hook(models, input_dir):
 
                 weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
                 loss = torch.mean(
-                    (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+                    (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+                    1,
                 )
                 loss = loss.mean()
                 accelerator.backward(loss)
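The loss itself is unchanged by this reflow: a per-sample weighted MSE over all latent dimensions, then a batch mean. Isolated, with illustrative shapes:

import torch

batch = 2
noise_pred_cond = torch.randn(batch, 4, 8, 8)  # model output (illustrative shape)
target = torch.randn(batch, 4, 8, 8)           # flow-matching target
weighting = torch.rand(batch, 1, 1, 1)         # per-sample loss weight

per_sample = torch.mean(
    (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
)  # shape (batch,) -- one loss value per sample
loss = per_sample.mean()  # scalar passed to accelerator.backward(loss)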

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 9 additions & 13 deletions
@@ -157,21 +157,17 @@ def __call__(
                 key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
             )
 
-        # 4. Attention
+        # 4. Attention and Attention Mask
         if attention_mask is not None:
-            # construct attention_mask for concated sequence
             text_attention_mask = attention_mask.float().to(query.device)
-            attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
-            attention_mask[:, :text_seq_length] = text_attention_mask
-            attention_mask = attention_mask.unsqueeze(2)
-            attention_mask_matrix = attention_mask @ attention_mask.mT
-            attention_mask_matrix = attention_mask_matrix == 1
-            attention_mask_matrix = attention_mask_matrix.unsqueeze(1)
-            attention_mask = attention_mask_matrix
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+            actual_text_seq_length = text_attention_mask.size(1)
+            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
+            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
+            new_attention_mask = new_attention_mask.unsqueeze(2)
+            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
+            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
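Both versions build a pairwise mask from a per-token text mask via an outer product: a (B, L, 1) column of validity flags times its (B, 1, L) transpose gives a (B, L, L) matrix that is nonzero exactly where both the query and key positions are valid. A standalone sketch of that idea; unlike the new hunk, it fills image positions with 1 (as the replaced torch.ones construction did, so image tokens stay attended) and keeps the mask boolean, since F.scaled_dot_product_attention treats a floating-point attn_mask as an additive bias rather than a keep/discard mask:

import torch
import torch.nn.functional as F

batch_size, text_seq_length, image_seq_length = 2, 4, 6
total_length = text_seq_length + image_seq_length

# 1 = real token, 0 = padding; the second prompt is shorter.
text_attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.float32)

# Extend to the concatenated text+image sequence; image tokens are always valid.
mask = torch.ones(batch_size, total_length)
mask[:, :text_seq_length] = text_attention_mask

# Outer product: (B, L, 1) @ (B, 1, L) -> (B, L, L).
matrix = mask.unsqueeze(2) @ mask.unsqueeze(1)
attn_mask = (matrix > 0).unsqueeze(1)  # (B, 1, L, L), broadcasts over heads

query = key = value = torch.randn(batch_size, 8, total_length, 16)  # (B, heads, L, head_dim)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
# Note: rows for fully padded query positions have every key masked, so their
# outputs are NaN; they correspond to padding and are ignored downstream.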
src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 8 additions & 14 deletions
@@ -144,13 +144,11 @@ class CogView4ControlPipeline(DiffusionPipeline):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`T5EncoderModel`]):
-            Frozen text-encoder. CogView4 uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
-            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`T5Tokenizer`):
+        text_encoder ([`GLMModel`]):
+            Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
+        tokenizer (`PreTrainedTokenizer`):
             Tokenizer of class
-            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+            [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
         transformer ([`CogView4Transformer2DModel`]):
             A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
@@ -182,7 +180,6 @@ def _get_glm_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
-        padding_type: str = "longest",
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -194,7 +191,7 @@
 
         text_inputs = self.tokenizer(
             prompt,
-            padding=padding_type,
+            padding="longest",  # pad to the longest prompt in the batch, not to max_length
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
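The hard-coded padding="longest" pads each batch only to its longest member, whereas "max_length" always pads to max_sequence_length. The difference, sketched with a generic fast tokenizer (the actual pipeline uses the GLM tokenizer; gpt2 here is purely illustrative and is downloaded on first use):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # gpt2 has no pad token by default
prompts = ["a cat", "a very detailed description of a cat sitting on a sunny windowsill"]

longest = tok(prompts, padding="longest", return_tensors="pt")
fixed = tok(prompts, padding="max_length", max_length=32, truncation=True, return_tensors="pt")

print(longest["input_ids"].shape)  # (2, length of the longest prompt)
print(fixed["input_ids"].shape)    # (2, 32) regardless of prompt lengths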
@@ -240,7 +237,6 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 1024,
-        padding_type: str = "longest",
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -278,7 +274,7 @@
         else:
             batch_size = prompt_embeds.shape[0]
         if prompt_embeds is None:
-            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, padding_type, device, dtype)
+            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
@@ -297,7 +293,7 @@
             )
 
             negative_prompt_embeds = self._get_glm_embeds(
-                negative_prompt, num_images_per_prompt, max_sequence_length, "longest", device, dtype
+                negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
             )
 
         return prompt_embeds, negative_prompt_embeds
@@ -451,7 +447,6 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 1024,
-        padding_type: str = "longest",  # For downstream tasks, it can be modified to use max_length for implementation.
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -581,8 +576,7 @@
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
-            padding_type=padding_type,
-            device=device
+            device=device,
         )
 
         # Prepare latents
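After this change the pipeline's __call__ no longer accepts padding_type; prompts are always padded to the longest in the batch internally. A hypothetical invocation under the new signature (the model id and control image URL are illustrative, and the import assumes the class is exported at the diffusers package root):

import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
image = pipe(
    prompt="a red bird perched on a snowy branch",
    control_image=control_image,
    num_inference_steps=50,
    guidance_scale=3.5,
    max_sequence_length=1024,
).images[0]
image.save("bird.png")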
