@@ -397,59 +397,62 @@ def forward(
         hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
             hidden_states, prompt_embeds, negative_prompt_embeds
         )
+        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

-        encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
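+        # run the conditional and unconditional (classifier-free guidance) streams separately
+        # instead of as a single concatenated batch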
+        encoder_hidden_states_cond = prompt_embeds
+        encoder_hidden_states_uncond = negative_prompt_embeds
+        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+        emb_cond, emb_uncond = emb.chunk(2)

         # prepare image_rotary_emb
         image_rotary_emb = self.get_rope_embedding(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )

-        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
-
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb=emb,
+                ...
+            else:
+                hidden_states_cond, encoder_hidden_states_cond = block(
+                    hidden_states=hidden_states_cond,
+                    encoder_hidden_states=encoder_hidden_states_cond,
+                    emb=emb_cond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
-                    **ckpt_kwargs,
                 )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    emb=emb,
+                hidden_states_uncond, encoder_hidden_states_uncond = block(
+                    hidden_states=hidden_states_uncond,
+                    encoder_hidden_states=encoder_hidden_states_uncond,
+                    emb=emb_uncond,  # refactor later
                     image_rotary_emb=image_rotary_emb,
                 )

-        hidden_states = self.norm_out(hidden_states, emb)  # result corresponds to final_layer_input in Megatron
-        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+        hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond)  # result corresponds to final_layer_input in Megatron
+        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond)  # use the matching emb chunk for each stream
+        hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+        hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)

         # unpatchify
         patch_size = self.config.patch_size
         height = height // patch_size
         width = width // patch_size

-        hidden_states = hidden_states.reshape(
-            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        hidden_states_cond = hidden_states_cond.reshape(
+            shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        )
+        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
+        output_cond = hidden_states_cond.reshape(
+            shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        )
+
+        hidden_states_uncond = hidden_states_uncond.reshape(
+            shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
         )
-        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
-        output = hidden_states.reshape(
-            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
+        output_uncond = hidden_states_uncond.reshape(
+            shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
         )

         if not return_dict:
-            return (output,)
+            return (output_cond, output_uncond)

-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
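With this change, forward() returns separate conditional and unconditional predictions instead of a single stacked output. A minimal, runnable sketch of how a caller might combine the two; the tensor shapes and guidance_scale value are illustrative assumptions, not taken from this commit:

import torch

# hypothetical stand-ins for the output_cond / output_uncond returned above,
# shaped (batch_size, out_channels, height, width)
output_cond = torch.randn(1, 16, 128, 128)
output_uncond = torch.randn(1, 16, 128, 128)

guidance_scale = 5.0  # illustrative value
# standard classifier-free guidance combination
noise_pred = output_uncond + guidance_scale * (output_cond - output_uncond)

The gradient-checkpointing branch is left as a bare "..." placeholder in this commit. One way it could be filled in for the split streams, mirroring the removed checkpointing code above (a sketch under that assumption, shown for the conditional stream only; the unconditional stream would be wrapped the same way):

def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        # forward both positional and keyword arguments to the wrapped block
        return module(*inputs, **kwargs)

    return custom_forward

hidden_states_cond, encoder_hidden_states_cond = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states_cond,
    encoder_hidden_states_cond,
    emb=emb_cond,
    image_rotary_emb=image_rotary_emb,
    use_reentrant=False,  # matches the removed ckpt_kwargs for torch >= 1.11
)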