
Commit f8945ce

[WIP] Add tensor-reload to align input from transformer block
1 parent 7916140 commit f8945ce

File tree

3 files changed: +51 −7 lines


src/diffusers/models/normalization.py

Lines changed: 9 additions & 0 deletions

@@ -333,9 +333,18 @@ def __init__(

     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
+
+        ####################################
         emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+        # emb = self.linear(conditioning_embedding).to(x.dtype)
+        ####################################
+
         scale, shift = torch.chunk(emb, 2, dim=1)
+
+        ############################
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        # x = x * (1 + scale)[:, None, :] + shift[:, None, :]
+        ############################
         return x
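For reference, a minimal self-contained sketch of the adaptive-norm modulation this hunk instruments (the class name, dimensions, and eps are illustrative assumptions, not the actual diffusers module); the commented-out variant in the diff simply drops the SiLU before the projection:

import torch
import torch.nn as nn

class AdaNormSketch(nn.Module):
    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-5):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(cond_dim, 2 * dim)  # projects conditioning to scale and shift
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(cond).to(x.dtype))  # active path in the diff
        # emb = self.linear(cond).to(x.dtype)           # commented-out variant: no SiLU
        scale, shift = torch.chunk(emb, 2, dim=1)
        return self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]

x = torch.randn(2, 16, 64)   # (batch, seq, dim)
cond = torch.randn(2, 128)   # (batch, cond_dim)
print(AdaNormSketch(64, 128)(x, cond).shape)  # torch.Size([2, 16, 64])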
src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 36 additions & 6 deletions

@@ -232,7 +232,8 @@ def __init__(
             embedding_dim=self.inner_dim,
             conditioning_embedding_dim=time_embed_dim,
             elementwise_affine=False,
-            eps=1e-6,
+            # eps=1e-6,
+            eps=1e-5,
         )
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
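The eps change matters for numerical parity: two otherwise identical LayerNorms with different eps produce slightly different outputs, and the discrepancy compounds across transformer blocks. A quick check (shapes are illustrative; 1e-5 is PyTorch's LayerNorm default, and that it matches the Megatron reference is an assumption of this WIP):

import torch
import torch.nn as nn

x = torch.randn(1, 4096, 2560)
ln_a = nn.LayerNorm(2560, elementwise_affine=False, eps=1e-6)
ln_b = nn.LayerNorm(2560, elementwise_affine=False, eps=1e-5)  # PyTorch default
print((ln_a(x) - ln_b(x)).abs().max())  # small but nonzero, on the order of 1e-5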
@@ -399,8 +400,6 @@ def forward(
         )
         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
         hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
         emb_cond, emb_uncond = emb.chunk(2)

@@ -409,6 +408,22 @@ def forward(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )

+        ######################
+        # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
+        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
+        prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+        negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+
+        hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
+        hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+
+        emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
+        emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
+        ######################
+
+        encoder_hidden_states_cond = prompt_embeds
+        encoder_hidden_states_uncond = negative_prompt_embeds
+
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
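These torch.load calls overwrite the module's computed activations with tensors dumped from the reference implementation, so everything downstream of this point can be diffed in isolation. The dump side is not part of this commit; a hypothetical sketch of the counterpart (only the directory and file names come from the diff above; the helper and its placement are assumptions):

import torch

def dump(tensor: torch.Tensor, name: str) -> None:
    # detach and move to CPU so the file is loadable from any device
    torch.save(tensor.detach().cpu(), f"/home/lhy/code/cogview/{name}.pt")

# e.g. inside the reference (Megatron) forward pass:
# dump(vision_input[0:4096], "cp_vision_input_0_4096")
# dump(condition_embedding[0:16], "cp_condition_0_16")
# dump(time_embedding[0:1], "time_embedding_0_1")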
@@ -418,16 +433,31 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states_cond,
                     emb=emb_cond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
+                    # image_rotary_emb=None,
                 )
+                ###########################
+                # hidden_states_cond, encoder_hidden_states_cond = (
+                #     self.norm_out.norm(hidden_states_cond),
+                #     self.norm_out.norm(encoder_hidden_states_cond),
+                # )
+                ###########################
+
                 hidden_states_uncond, encoder_hidden_states_uncond = block(
                     hidden_states=hidden_states_uncond,
                     encoder_hidden_states=encoder_hidden_states_uncond,
                     emb=emb_uncond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
+                    # image_rotary_emb=None,
                 )
-
-        hidden_states_cond = self.norm_out(hidden_states_cond, emb)  # result corresponds to final_layer_input in Megatron
-        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb)  # result corresponds to final_layer_input in Megatron
+                ###########################
+                # hidden_states_uncond, encoder_hidden_states_uncond = (
+                #     self.norm_out.norm(hidden_states_uncond),
+                #     self.norm_out.norm(encoder_hidden_states_uncond),
+                # )
+                ###########################
+
+        hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond)  # result corresponds to final_layer_input in Megatron
+        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond)  # result corresponds to final_layer_input in Megatron
         hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
         hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
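The forward pass keeps separate cond/uncond streams instead of one batched pass. Downstream, the two outputs are typically recombined with standard classifier-free guidance; that step is not part of this diff, but a generic sketch of it:

import torch

def cfg_combine(noise_uncond: torch.Tensor, noise_cond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # move the unconditional prediction toward the conditional one
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

noise_cond = torch.randn(1, 4096, 64)
noise_uncond = torch.randn(1, 4096, 64)
print(cfg_combine(noise_uncond, noise_cond, guidance_scale=5.0).shape)  # torch.Size([1, 4096, 64])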
src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 6 additions & 1 deletion

@@ -216,7 +216,9 @@ def _get_glm_embeds(
                 device=text_input_ids.device,
             )
             text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
-        prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True).hidden_states[-2]
+        prompt_embeds = self.text_encoder(
+            text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
+        ).hidden_states[-2]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
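The reflowed call keeps the existing behaviour: GLM prompt embeddings are taken from the second-to-last hidden state rather than the final layer output. The same pattern in generic transformers code (gpt2 is only a stand-in model for illustration):

import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")
ids = tok("a cat sitting on a windowsill", return_tensors="pt").input_ids
out = model(ids, output_hidden_states=True)
penultimate = out.hidden_states[-2]  # (batch, seq_len, hidden), one layer before the top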
@@ -592,6 +594,7 @@ def __call__(

         # Prepare latents.
         latent_channels = self.transformer.config.in_channels
+        #########################
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             latent_channels,

@@ -602,6 +605,8 @@ def __call__(
             generator,
             latents,
         )
+        latents = torch.ones_like(latents)
+        #########################

         # Prepare additional timestep conditions
         original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
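Overwriting the prepared latents with torch.ones_like strips out the random initial noise, so both implementations start the denoising loop from byte-identical inputs. A seeded alternative that keeps the noise distribution while staying reproducible (an assumption, not part of this commit):

import torch

latents = torch.randn(1, 16, 128, 128)  # stand-in for the prepare_latents output
latents = torch.ones_like(latents)      # what the diff does: constant latents

# seeded alternative, deterministic but still Gaussian:
generator = torch.Generator("cpu").manual_seed(0)
latents = torch.randn(1, 16, 128, 128, generator=generator)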
