
Commit f608f82

[WIP][cogview4][refactor]: Split condition/uncondition forward pass in CogView4 pipeline
Split the forward pass for conditional and unconditional predictions in the CogView4 pipeline to match the original implementation. The noise prediction is now computed separately for each case before the two are combined for guidance. This is still a work in progress: the generated images do not yet match the expected quality.
1 parent 0ab7260 commit f608f82
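
For context, the guidance step the commit message refers to combines the two branch predictions as `noise_uncond + guidance_scale * (noise_cond - noise_uncond)`. A minimal, self-contained sketch of that combination (illustrative names, not the pipeline's exact code):

```python
import torch

def cfg_combine(noise_pred_cond: torch.Tensor,
                noise_pred_uncond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # classifier-free guidance: start from the unconditional prediction and
    # push it toward the conditional one, scaled by guidance_scale
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# quick shape check with dummy latents
cond = torch.randn(1, 4, 64, 64)
uncond = torch.randn(1, 4, 64, 64)
assert cfg_combine(cond, uncond, guidance_scale=5.0).shape == cond.shape
```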

File tree

2 files changed: +35, -49 lines


src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 34 additions & 31 deletions
```diff
@@ -397,59 +397,62 @@ def forward(
         hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
             hidden_states, prompt_embeds, negative_prompt_embeds
         )
+        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
 
-        encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+        encoder_hidden_states_cond = prompt_embeds
+        encoder_hidden_states_uncond = negative_prompt_embeds
+        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+        emb_cond, emb_uncond = emb.chunk(2)
 
         # prepare image_rotary_emb
         image_rotary_emb = self.get_rope_embedding(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )
 
-        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
-
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb=emb,
+                ...
+            else:
+                hidden_states_cond, encoder_hidden_states_cond = block(
+                    hidden_states=hidden_states_cond,
+                    encoder_hidden_states=encoder_hidden_states_cond,
+                    emb=emb_cond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
-                    **ckpt_kwargs,
                 )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    emb=emb,
+                hidden_states_uncond, encoder_hidden_states_uncond = block(
+                    hidden_states=hidden_states_uncond,
+                    encoder_hidden_states=encoder_hidden_states_uncond,
+                    emb=emb_uncond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
                 )
 
-        hidden_states = self.norm_out(hidden_states, emb)  # corresponds to final_layer_input in Megatron
-        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+        hidden_states_cond = self.norm_out(hidden_states_cond, emb)  # corresponds to final_layer_input in Megatron
+        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb)  # corresponds to final_layer_input in Megatron
+        hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+        hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
 
         # unpatchify
         patch_size = self.config.patch_size
         height = height // patch_size
         width = width // patch_size
 
-        hidden_states = hidden_states.reshape(
-            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        hidden_states_cond = hidden_states_cond.reshape(
+            shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        )
+        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
+        output_cond = hidden_states_cond.reshape(
+            shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        )
+
+        hidden_states_uncond = hidden_states_uncond.reshape(
+            shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
         )
-        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
-        output = hidden_states.reshape(
-            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
+        output_uncond = hidden_states_uncond.reshape(
+            shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
         )
 
         if not return_dict:
-            return (output,)
+            return (output_cond, output_uncond)
 
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
```
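
Note the gradient-checkpointing branch is left as a bare `...` in this commit. One plausible per-branch shape for it, adapted from the helper the commit deletes (a hedged sketch, not the final implementation; `run_block_checkpointed` is a hypothetical name):

```python
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module):
    # thin wrapper so checkpoint() can re-invoke the block during backward
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward

def run_block_checkpointed(block, hidden_states, encoder_hidden_states, emb, image_rotary_emb):
    # non-reentrant checkpointing forwards keyword arguments to the block
    return checkpoint(
        create_custom_forward(block),
        hidden_states,
        encoder_hidden_states,
        emb=emb,
        image_rotary_emb=image_rotary_emb,
        use_reentrant=False,
    )
```

In the split design this would run twice per block, once with the `*_cond` tensors and once with the `*_uncond` tensors.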

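The unpatchify tail runs identically on both branches. A standalone toy with made-up dimensions, to make the reshape, einsum, reshape pattern above concrete:

```python
import torch

batch, out_channels, patch_size = 2, 16, 2
height, width = 4, 4  # measured in patches
tokens = torch.randn(batch, height * width, patch_size * patch_size * out_channels)

# (B, H*W, p*p*C) -> (B, H, W, C, p, p) -> (B, C, H, p, W, p) -> (B, C, H*p, W*p)
x = tokens.reshape(batch, height, width, out_channels, patch_size, patch_size)
x = torch.einsum("nhwcpq->nchpwq", x)
image = x.reshape(batch, out_channels, height * patch_size, width * patch_size)
assert image.shape == (2, 16, 8, 8)
```
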
src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 1 addition & 18 deletions
```diff
@@ -314,23 +314,6 @@ def encode_prompt(
             dtype=dtype,
         )
 
-        # TODO: pad with 0 for now; handle differing sequence lengths properly later (lhy: try padding with the pad token here instead)
-        seq_len_prompt = prompt_embeds.shape[1]
-        seq_len_neg = negative_prompt_embeds.shape[1]
-        if seq_len_neg < seq_len_prompt:
-            # create a new tensor of size [batch_size, seq_len_prompt, hidden_size]
-            batch_size, seq_len, hidden_size = negative_prompt_embeds.shape
-            # the padded tensor
-            padded_negative_prompt = torch.full(
-                (batch_size, seq_len_prompt - seq_len_neg),
-                fill_value=self.tokenizer.pad_token_id,
-                device=negative_prompt_embeds.device,
-            )
-            padded_negative_prompt_embeds = self.text_encoder.model.embed_tokens(
-                padded_negative_prompt.to(self.text_encoder.model.device)
-            )
-            negative_prompt_embeds = torch.cat([padded_negative_prompt_embeds, negative_prompt_embeds], dim=1)
-            assert negative_prompt_embeds.shape == prompt_embeds.shape
         return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -680,7 +663,7 @@ def __call__(
 
         # perform guidance
         if do_classifier_free_guidance:
-            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+            noise_pred_cond, noise_pred_uncond = noise_pred
             noise_pred_guided = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
 
         ###########################
```
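
The ordering change matters: `.chunk(2)` on a batched prediction yields halves in concatenation order, while the transformer now returns an explicit `(cond, uncond)` tuple. A toy check that both layouts produce the same guided result when unpacked consistently (dummy tensors, simplified from the loop above):

```python
import torch

guidance_scale = 5.0
cond = torch.randn(1, 4, 8, 8)
uncond = torch.randn(1, 4, 8, 8)

# old layout: one batched tensor laid out [uncond, cond], split with chunk
u_old, c_old = torch.cat([uncond, cond], dim=0).chunk(2)

# new layout: separate tensors returned as a (cond, uncond) tuple
c_new, u_new = (cond, uncond)

guided_old = u_old + guidance_scale * (c_old - u_old)
guided_new = u_new + guidance_scale * (c_new - u_new)
assert torch.allclose(guided_old, guided_new)
```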
