Skip to content

Commit b86bfd4

Browse files
Use the text encoder's second-to-last (-2) hidden state for prompt embeddings
1 parent f608f82 commit b86bfd4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
)
175175
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
176176
self.image_factor = 16
177-
177+
self.text_projector = torch.nn.Linear(4096, 4096)
178178
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
179179

180180
def _get_glm_embeds(
@@ -217,10 +217,12 @@ def _get_glm_embeds(
217217
device=text_input_ids.device,
218218
)
219219
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
220-
221-
prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids.to(self.text_encoder.model.device))[0]
220+
prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True).hidden_states[-2]
221+
self.text_projector.to(dtype=dtype, device=device)
222+
prompt_embeds = self.text_projector(prompt_embeds)
223+
breakpoint()
222224
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
223-
seq_len, _ = prompt_embeds.shape
225+
_, seq_len, _= prompt_embeds.shape
224226
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
225227
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
226228
return prompt_embeds

0 commit comments

Comments
 (0)