Commit 8304adc

Make zeroing prompt embeds for Mochi Pipeline configurable (huggingface#10284)
update
1 parent b389f33 commit 8304adc

File tree: 1 file changed (+8, -2)


src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 8 additions & 2 deletions
@@ -188,6 +188,7 @@ def __init__(
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
+        force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()

@@ -205,10 +206,11 @@ def __init__(

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
         )
         self.default_height = 480
         self.default_width = 848
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

     def _get_t5_prompt_embeds(
         self,
@@ -236,7 +238,11 @@ def _get_t5_prompt_embeds(
         text_input_ids = text_inputs.input_ids
         prompt_attention_mask = text_inputs.attention_mask
         prompt_attention_mask = prompt_attention_mask.bool().to(device)
-        if prompt == "" or prompt[-1] == "":
+
+        # The original Mochi implementation zeros out empty negative prompts
+        # but this can lead to overflow when placing the entire pipeline under the autocast context
+        # adding this here so that we can enable zeroing prompts if necessary
+        if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
             text_input_ids = torch.zeros_like(text_input_ids, device=device)
             prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
