Skip to content

Commit 69423ec

Browse files
committed
fix cache-dit bug
Signed-off-by: David Chen <530634352@qq.com>
1 parent 5ae528a commit 69423ec

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

vllm_omni/diffusion/cache/cache_dit_backend.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,49 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
351351

352352
def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
353353
"""Enable cache-dit for LTX2 pipelines (audio-video transformer blocks)."""
354+
transformer = pipeline.transformer
355+
356+
if not getattr(transformer, "_cache_dit_ltx2_patched", False):
357+
def _wrap_block_forward(orig):
    """Wrap a block's ``forward`` so it accepts a two-stream calling convention.

    The returned function takes ``(self, hidden_states, encoder_hidden_states,
    *args, **kwargs)`` and re-dispatches to ``orig`` with ten leading
    positional arguments in this fixed order: ``hidden_states``,
    ``audio_hidden_states``, ``encoder_hidden_states``,
    ``audio_encoder_hidden_states``, ``temb``, ``temb_audio``,
    ``temb_ca_scale_shift``, ``temb_ca_audio_scale_shift``, ``temb_ca_gate``,
    ``temb_ca_audio_gate``. Any remaining keyword arguments are passed through
    unchanged.

    Args:
        orig: The callable to delegate to. It is invoked without ``self``, so
            it is expected to be already bound (or a plain function).

    Returns:
        A function with a leading ``self`` parameter, suitable for binding to
        a block via ``__get__``.
    """
    # Keyword names of the slots that follow (hidden_states,
    # audio_hidden_states) in the positional order `orig` is called with.
    trailing_slots = (
        "encoder_hidden_states",
        "audio_encoder_hidden_states",
        "temb",
        "temb_audio",
        "temb_ca_scale_shift",
        "temb_ca_audio_scale_shift",
        "temb_ca_gate",
        "temb_ca_audio_gate",
    )

    @functools.wraps(orig)
    def _wrapped_forward(self, hidden_states, encoder_hidden_states=None, *args, **kwargs):
        # The second argument of the wrapper is forwarded as the second
        # argument of `orig` (the audio stream slot). A keyword
        # `audio_hidden_states` is only a fallback when that argument is None;
        # otherwise the keyword value is discarded.
        # NOTE(review): because the wrapper itself has a parameter named
        # `encoder_hidden_states`, a keyword `encoder_hidden_states=` binds to
        # that parameter and is therefore routed to the audio slot; it can
        # never reach `kwargs`. Confirm this matches what the caller intends.
        audio_hidden_states = encoder_hidden_states
        audio_from_kwargs = kwargs.pop("audio_hidden_states", None)
        if audio_hidden_states is None:
            audio_hidden_states = audio_from_kwargs

        # Fill each trailing slot from the extra positional arguments first
        # (previously these were silently dropped), then from keywords.
        # A positional value wins over a keyword of the same name, mirroring
        # the audio_hidden_states handling above.
        slot_values = []
        for position, slot_name in enumerate(trailing_slots):
            keyword_value = kwargs.pop(slot_name, None)
            if position < len(args):
                slot_values.append(args[position])
            else:
                slot_values.append(keyword_value)

        # Positional arguments beyond the known slots are passed through
        # rather than discarded.
        overflow = args[len(trailing_slots):]

        return orig(
            hidden_states,
            audio_hidden_states,
            *slot_values,
            *overflow,
            **kwargs,
        )

    return _wrapped_forward
391+
392+
for block in transformer.transformer_blocks:
393+
block.forward = _wrap_block_forward(block.forward).__get__(block, block.__class__)
394+
395+
transformer._cache_dit_ltx2_patched = True
396+
354397
db_cache_config = _build_db_cache_config(cache_config)
355398

356399
calibrator_config = None
@@ -359,7 +402,6 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
359402
calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
360403
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
361404

362-
transformer = pipeline.transformer
363405
blocks = transformer.transformer_blocks
364406

365407
logger.info(

0 commit comments

Comments
 (0)