 from ...utils.torch_utils import randn_tensor
 from .pipeline_output import CogView4PipelineOutput
 
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```python
@@ -180,39 +178,47 @@ def _get_glm_embeds(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding="max_length",
+            padding="longest",  # do not pad to max_length here
             max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
                 f" {max_sequence_length} tokens: {removed_text}"
             )
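+        # Left-pad the token ids up to a multiple of 16 with the GLM-4 <|endoftext|>
+        # token, presumably so the sequence length meets an alignment requirement of
+        # the downstream attention kernels.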
+        current_length = text_input_ids.shape[1]
+        pad_length = (16 - (current_length % 16)) % 16
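+        # e.g. a 23-token prompt gives pad_length = (16 - 23 % 16) % 16 = 9; a length
+        # that is already a multiple of 16 gives pad_length = 0.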
+        if pad_length > 0:
+            pad_ids = torch.full(
+                (text_input_ids.shape[0], pad_length),
+                fill_value=151329,  # <|endoftext|> of glm-4
+                dtype=text_input_ids.dtype,
+                device=text_input_ids.device,
+            )
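+            # Prepending rather than appending keeps the prompt tokens at the end of
+            # the sequence, the usual convention for decoder-only language models.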
+            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
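+        # `model.embed_tokens` (the input-embedding table in the transformers GLM
+        # layout) replaces the full text-encoder forward pass used before; it returns
+        # (batch, seq_len, hidden) and `[0]` drops the batch dimension.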
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
+        seq_len, _ = prompt_embeds.shape
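+        # repeat() on the now 2-D (seq_len, hidden) tensor prepends a dimension,
+        # yielding (1, seq_len * num_images_per_prompt, hidden); the view below splits
+        # it into one copy per image (only shape-consistent when batch_size == 1).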
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         return prompt_embeds
 
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         do_classifier_free_guidance: bool = True,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 224,
+        max_sequence_length: int = 1024,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -222,6 +228,10 @@ def encode_prompt(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
             do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                 Whether to use classifier free guidance or not.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -233,7 +243,7 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            max_sequence_length (`int`, defaults to `224`):
+            max_sequence_length (`int`, defaults to `1024`):
                 Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
             device: (`torch.device`, *optional*):
                 torch device
@@ -249,7 +259,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds = self._get_glm_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -258,7 +268,13 @@ def encode_prompt(
             )
 
         if do_classifier_free_guidance and negative_prompt is None:
-            negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
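+            # Unconditional branch for classifier-free guidance: encode an empty
+            # prompt through the text encoder instead of using all-zero embeddings.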
+            negative_prompt_embeds = self._get_glm_embeds(
+                prompt="",
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
@@ -275,14 +291,13 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )
 
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds = self._get_glm_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
                 dtype=dtype,
             )
-
         return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -422,7 +437,7 @@ def __call__(
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        CogView4Pipeline: int = 224,
+        max_sequence_length: int = 1024,
     ) -> Union[CogView4PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -543,7 +558,6 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,