
Commit 6163679

encode with glm
1 parent eba11fa commit 6163679


src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 30 additions & 16 deletions
@@ -28,7 +28,6 @@
 from ...utils.torch_utils import randn_tensor
 from .pipeline_output import CogView4PipelineOutput

-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -38,7 +37,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```python
@@ -180,39 +178,47 @@ def _get_glm_embeds(
         text_inputs = self.tokenizer(
             prompt,
-            padding="max_length",
+            padding="longest",  # not use max length
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
             )
+        current_length = text_input_ids.shape[1]
+        pad_length = (16 - (current_length % 16)) % 16
+        if pad_length > 0:
+            pad_ids = torch.full(
+                (text_input_ids.shape[0], pad_length),
+                fill_value=151329,  # <|endoftext|> of glm-4
+                dtype=text_input_ids.dtype,
+                device=text_input_ids.device,
+            )
+            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)

-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
+        seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         return prompt_embeds

     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         do_classifier_free_guidance: bool = True,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 224,
+        max_sequence_length: int = 1024,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
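Note on the hunk above: the tokenized prompt is now left-padded with GLM-4's `<|endoftext|>` token (id 151329) so that its length becomes a multiple of 16, and the prompt embeddings are taken from `text_encoder.model.embed_tokens` (an input embedding lookup) rather than from a full encoder forward pass. A minimal sketch of the padding arithmetic, assuming plain `torch`; the helper name `pad_to_multiple_of_16` is hypothetical and not part of the diff:

```python
import torch

def pad_to_multiple_of_16(text_input_ids: torch.Tensor, pad_token_id: int = 151329) -> torch.Tensor:
    # Length of the tokenized prompt after padding="longest"
    current_length = text_input_ids.shape[1]
    # Pad tokens needed to reach the next multiple of 16 (0 if already aligned)
    pad_length = (16 - (current_length % 16)) % 16
    if pad_length == 0:
        return text_input_ids
    pad_ids = torch.full(
        (text_input_ids.shape[0], pad_length),
        fill_value=pad_token_id,
        dtype=text_input_ids.dtype,
        device=text_input_ids.device,
    )
    # The diff prepends the pad ids, i.e. the sequence is left-padded
    return torch.cat([pad_ids, text_input_ids], dim=1)

# Example: a (1, 37) batch of token ids becomes (1, 48) after adding 11 pad tokens on the left.
ids = torch.randint(0, 1000, (1, 37))
print(pad_to_multiple_of_16(ids).shape)  # torch.Size([1, 48])
```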
@@ -222,6 +228,10 @@ def encode_prompt(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
             do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                 Whether to use classifier free guidance or not.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -233,7 +243,7 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            max_sequence_length (`int`, defaults to `224`):
+            max_sequence_length (`int`, defaults to `1024`):
                 Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
             device: (`torch.device`, *optional*):
                 torch device
@@ -249,7 +259,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds = self._get_glm_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -258,7 +268,13 @@ def encode_prompt(
             )

         if do_classifier_free_guidance and negative_prompt is None:
-            negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
+            negative_prompt_embeds = self._get_glm_embeds(
+                prompt="",
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )

         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
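For context on the hunk above: when no `negative_prompt` is given, the unconditional embeddings used for classifier-free guidance are now the GLM encoding of the empty string rather than an all-zeros tensor. The standard CFG blend that these unconditional embeddings eventually feed into is sketched below (a generic illustration with made-up variable names, not code from this file):

```python
import torch

def classifier_free_guidance(
    noise_pred_uncond: torch.Tensor,  # prediction from the "" / negative-prompt branch
    noise_pred_text: torch.Tensor,    # prediction from the actual prompt branch
    guidance_scale: float,
) -> torch.Tensor:
    # At guidance_scale = 1 this returns the text-conditioned prediction unchanged,
    # which is why the pipeline only enables CFG when guidance_scale > 1.0.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```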
@@ -275,14 +291,13 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )

-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds = self._get_glm_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
                 dtype=dtype,
             )
-
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -422,7 +437,7 @@ def __call__(
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        CogView4Pipeline: int = 224,
+        max_sequence_length: int = 1024,
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -543,7 +558,6 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
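Taken together with the new `max_sequence_length` default of 1024 (previously 224), a rough usage sketch of the updated pipeline follows. The checkpoint id `THUDM/CogView4-6B` and the exact `__call__` arguments are assumptions for illustration, not something this diff confirms:

```python
import torch
from diffusers import CogView4Pipeline

# Assumed checkpoint id; substitute the checkpoint you actually use.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a watercolor painting of a red fox in the snow",
    negative_prompt="low quality, blurry",  # now encoded with GLM rather than replaced by zeros
    guidance_scale=5.0,
    num_images_per_prompt=1,
).images[0]
image.save("fox.png")
```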
