@@ -174,7 +174,7 @@ def __init__(
174174 )
175175 self .vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 ) if getattr (self , "vae" , None ) else 8
176176 self .image_factor = 16
177-
177+ self . text_projector = torch . nn . Linear ( 4096 , 4096 )
178178 self .image_processor = VaeImageProcessor (vae_scale_factor = self .vae_scale_factor )
179179
180180 def _get_glm_embeds (
@@ -217,10 +217,12 @@ def _get_glm_embeds(
217217 device = text_input_ids .device ,
218218 )
219219 text_input_ids = torch .cat ([pad_ids , text_input_ids ], dim = 1 )
220-
221- prompt_embeds = self .text_encoder .model .embed_tokens (text_input_ids .to (self .text_encoder .model .device ))[0 ]
220+ prompt_embeds = self .text_encoder (text_input_ids .to (self .text_encoder .model .device ), output_hidden_states = True ).hidden_states [- 2 ]
221+ self .text_projector .to (dtype = dtype , device = device )
222+ prompt_embeds = self .text_projector (prompt_embeds )
223+ breakpoint ()
222224 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
223- seq_len , _ = prompt_embeds .shape
225+ _ , seq_len , _ = prompt_embeds .shape
224226 prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
225227 prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
226228 return prompt_embeds
0 commit comments