src/diffusers/models/transformers/transformer_lumina2.py (2 changes: 1 addition & 1 deletion)
@@ -448,4 +448,4 @@ def forward(
         if not return_dict:
             return (output,)

-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
src/diffusers/pipelines/lumina2/pipeline_lumina2.py (15 changes: 7 additions & 8 deletions)
@@ -208,7 +208,6 @@ def __init__(
         )
         self.vae_scale_factor = 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        self.max_sequence_length = 256
         self.default_sample_size = (
             self.transformer.config.sample_size
             if hasattr(self, "transformer") and self.transformer is not None
@@ -222,7 +221,7 @@ def _get_gemma_prompt_embeds(
         prompt: Union[str, List[str]],
         num_images_per_prompt: int = 1,
         device: Optional[torch.device] = None,
-        max_length: Optional[int] = None,
+        max_length: int = 256,
     ):
         device = device or self._execution_device
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -231,7 +230,7 @@ def _get_gemma_prompt_embeds(
         text_inputs = self.tokenizer(
             prompt,
             pad_to_multiple_of=8,
-            max_length=self.max_sequence_length,
+            max_length=max_length,
             truncation=True,
             padding=True,
             return_tensors="pt",
@@ -240,10 +239,10 @@ def _get_gemma_prompt_embeds(
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1])
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {self.max_sequence_length} tokens: {removed_text}"
+                f" {max_length} tokens: {removed_text}"
             )

         prompt_attention_mask = text_inputs.attention_mask.to(device)
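The truncation warning in `_get_gemma_prompt_embeds` now reports the per-call `max_length` rather than the deleted pipeline attribute. A minimal sketch of that check, with hand-picked token ids and `max_length=6` standing in for real Gemma tokenizer output and the 256-token default:

```python
import torch

max_length = 6  # illustrative; the pipeline default is 256

# Stand-ins for tokenizer output: ids truncated to max_length vs. the full ("longest"-padded) prompt.
text_input_ids = torch.tensor([[2, 10, 11, 12, 13, 14]])
untruncated_ids = torch.tensor([[2, 10, 11, 12, 13, 14, 15, 16, 1]])

# Same condition the pipeline uses; torch.equal returns False for tensors of different shapes.
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
    # Slice decoded for the warning: ids from index max_length - 1 up to (excluding) the final special token.
    removed = untruncated_ids[:, max_length - 1 : -1]
    print(f"Truncated because Gemma can only handle sequences up to {max_length} tokens: {removed.tolist()}")
```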
@@ -283,6 +282,7 @@ def encode_prompt(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
+        max_sequence_length: int = 256,
         **kwargs,
     ):
         r"""
@@ -326,6 +326,7 @@ def encode_prompt(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
+                max_length=max_sequence_length,
             )

         # Get negative embeddings for classifier free guidance
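With the constructor attribute removed, the token budget becomes a per-call argument: `encode_prompt` accepts `max_sequence_length` (default 256) and forwards it to `_get_gemma_prompt_embeds` as `max_length`. A rough usage sketch, assuming a loaded pipeline object; the class name, checkpoint id, and the exact return layout of `encode_prompt` are assumptions for illustration, not taken from this diff:

```python
import torch
from diffusers import Lumina2Pipeline  # class name assumed; check your diffusers version

# Checkpoint id is illustrative.
pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# max_sequence_length can now be tightened or raised per call instead of being
# fixed to 256 at construction time.
outputs = pipe.encode_prompt(
    prompt="a watercolor lighthouse at dawn, soft morning light",
    max_sequence_length=256,
)
prompt_embeds, prompt_attention_mask = outputs[0], outputs[1]  # return layout assumed
```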
@@ -791,8 +792,6 @@ def __call__(
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)

-                progress_bar.update()
-
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
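The deleted `progress_bar.update()` looks like a duplicate call inside the denoising loop: with two updates per step, a bar sized to `num_inference_steps` reports completion halfway through sampling. A small sketch of that effect with plain tqdm (which the pipelines' `progress_bar` helper is built on); the step count is made up:

```python
from tqdm import tqdm

num_inference_steps = 10  # illustrative
with tqdm(total=num_inference_steps) as bar:
    for step in range(num_inference_steps):
        # ... one denoising step would run here ...
        bar.update()   # one update per step keeps the bar in sync with the loop
        # bar.update()  # a second call per step, as removed in this diff, would finish the bar at step 5
```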
@@ -822,4 +821,4 @@ def __call__(
         if not return_dict:
             return (image,)

-        return ImagePipelineOutput(images=image)
+        return ImagePipelineOutput(images=image)
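The final hunk only adjusts the file's trailing newline around the standard return convention, which the context lines make explicit: `return_dict=False` yields a plain tuple, otherwise an `ImagePipelineOutput` whose `images` field holds the generated images. A short usage note, reusing the hypothetical `pipe` from the earlier sketch:

```python
# Default: an ImagePipelineOutput with an `images` field.
result = pipe("a watercolor lighthouse at dawn", num_inference_steps=30)
image = result.images[0]

# return_dict=False: a one-element tuple containing the images.
(images,) = pipe("a watercolor lighthouse at dawn", num_inference_steps=30, return_dict=False)
image = images[0]
```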