diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index 915aae615aff..4086cd2f51aa 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -448,4 +448,4 @@ def forward(
         if not return_dict:
             return (output,)
 
-        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
+        return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
index ce4eb36702fc..64744baba565 100644
--- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -208,7 +208,6 @@ def __init__(
         )
         self.vae_scale_factor = 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        self.max_sequence_length = 256
         self.default_sample_size = (
             self.transformer.config.sample_size
             if hasattr(self, "transformer") and self.transformer is not None
@@ -222,7 +221,7 @@ def _get_gemma_prompt_embeds(
         prompt: Union[str, List[str]],
         num_images_per_prompt: int = 1,
         device: Optional[torch.device] = None,
-        max_length: Optional[int] = None,
+        max_length: int = 256,
     ):
         device = device or self._execution_device
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -231,7 +230,7 @@ def _get_gemma_prompt_embeds(
         text_inputs = self.tokenizer(
             prompt,
             pad_to_multiple_of=8,
-            max_length=self.max_sequence_length,
+            max_length=max_length,
             truncation=True,
             padding=True,
             return_tensors="pt",
@@ -240,10 +239,10 @@ def _get_gemma_prompt_embeds(
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1])
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {self.max_sequence_length} tokens: {removed_text}"
+                f" {max_length} tokens: {removed_text}"
             )
 
         prompt_attention_mask = text_inputs.attention_mask.to(device)
@@ -283,6 +282,7 @@ def encode_prompt(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
+        max_sequence_length: int = 256,
         **kwargs,
     ):
         r"""
@@ -326,6 +326,7 @@ def encode_prompt(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
+                max_length=max_sequence_length,
             )
 
         # Get negative embeddings for classifier free guidance
@@ -791,8 +792,6 @@ def __call__(
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
-                progress_bar.update()
-
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
@@ -822,4 +821,4 @@ def __call__(
         if not return_dict:
             return (image,)
 
-        return ImagePipelineOutput(images=image)
\ No newline at end of file
+        return ImagePipelineOutput(images=image)
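
Usage note (not part of the diff): a minimal sketch of how the relocated limit is exercised after this change. It assumes the public class exported for pipeline_lumina2.py is Lumina2Pipeline, that the "Alpha-VLLM/Lumina-Image-2.0" checkpoint id is available, and that encode_prompt's remaining arguments keep their existing defaults; the 256-token cap is now a per-call max_sequence_length argument rather than the removed self.max_sequence_length attribute.

    import torch
    from diffusers import Lumina2Pipeline  # assumed class name for pipeline_lumina2.py

    # Assumed checkpoint id, shown for illustration only.
    pipe = Lumina2Pipeline.from_pretrained(
        "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
    ).to("cuda")

    # After this PR the limit flows encode_prompt(max_sequence_length=...) ->
    # _get_gemma_prompt_embeds(max_length=...); 256 remains the default.
    # The exact layout of the returned embeddings/attention masks is not shown in this diff.
    embeds = pipe.encode_prompt(
        prompt="a watercolor lighthouse at dusk, highly detailed",
        num_images_per_prompt=1,
        max_sequence_length=256,
    )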