@@ -232,7 +232,8 @@ def __init__(
             embedding_dim=self.inner_dim,
             conditioning_embedding_dim=time_embed_dim,
             elementwise_affine=False,
-            eps=1e-6,
+            # eps=1e-6,
+            eps=1e-5,
         )
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

@@ -399,8 +400,6 @@ def forward(
         )
         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
         hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
         emb_cond, emb_uncond = emb.chunk(2)

@@ -409,6 +408,22 @@ def forward(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )

+        ######################
+        # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
+        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
+        prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+        negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+
+        hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
+        hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+
+        emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
+        emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
+        ######################
+
+        encoder_hidden_states_cond = prompt_embeds
+        encoder_hidden_states_uncond = negative_prompt_embeds
+
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ...
@@ -418,16 +433,31 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states_cond,
                     emb=emb_cond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
+                    # image_rotary_emb=None,
                 )
+                ###########################
+                # hidden_states_cond, encoder_hidden_states_cond = (
+                #     self.norm_out.norm(hidden_states_cond),
+                #     self.norm_out.norm(encoder_hidden_states_cond),
+                # )
+                ###########################
+
                 hidden_states_uncond, encoder_hidden_states_uncond = block(
                     hidden_states=hidden_states_uncond,
                     encoder_hidden_states=encoder_hidden_states_uncond,
                     emb=emb_uncond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
+                    # image_rotary_emb=None,
                 )
-
-        hidden_states_cond = self.norm_out(hidden_states_cond, emb)  # result corresponds to final_layer_input in Megatron
-        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb)  # result corresponds to final_layer_input in Megatron
+                ###########################
+                # hidden_states_uncond, encoder_hidden_states_uncond = (
+                #     self.norm_out.norm(hidden_states_uncond),
+                #     self.norm_out.norm(encoder_hidden_states_uncond),
+                # )
+                ###########################
+
+        hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond)  # result corresponds to final_layer_input in Megatron
+        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond)  # result corresponds to final_layer_input in Megatron
         hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
         hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)

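The substantive fix in the last hunk is that the final adaptive norm is now driven by the matching half of the time embedding (`emb_cond` / `emb_uncond`) instead of the full batched `emb`. A minimal sketch of that cond/uncond split, using made-up shapes and a placeholder stand-in for `AdaLayerNormContinuous` (the names, shapes, and projection inside `ada_norm` below are illustrative, not the actual CogView4 module):

```python
import torch

# Illustrative shapes only: batch of 2 = [cond, uncond], 4096 image tokens, hidden size 64.
hidden_states = torch.randn(2, 4096, 64)
emb = torch.randn(2, 64)  # pooled time/size conditioning, one row per sample

# Split the batched tensors into their conditional / unconditional halves.
hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
emb_cond, emb_uncond = emb.chunk(2)

def ada_norm(x, conditioning):
    # Placeholder for AdaLayerNormContinuous: shift/scale derived from the conditioning vector.
    scale, shift = conditioning.tanh(), conditioning.sigmoid()  # stand-in for the learned projection
    x = torch.nn.functional.layer_norm(x, x.shape[-1:], eps=1e-5)
    return x * (1 + scale[:, None]) + shift[:, None]

# Each half must be normalized with its own conditioning row,
# which is what the emb -> emb_cond / emb_uncond change does.
out_cond = ada_norm(hidden_states_cond, emb_cond)
out_uncond = ada_norm(hidden_states_uncond, emb_uncond)
print(out_cond.shape, out_uncond.shape)  # torch.Size([1, 4096, 64]) each
```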