Skip to content

Commit 96b1075

Browse files
committed
fix cache-dit bug2
Signed-off-by: David Chen <530634352@qq.com>
1 parent 3c664f0 commit 96b1075

File tree

2 files changed

+16
-54
lines changed

2 files changed

+16
-54
lines changed

vllm_omni/diffusion/cache/cache_dit_backend.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -353,44 +353,6 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
353353
"""Enable cache-dit for LTX2 pipelines (audio-video transformer blocks)."""
354354
transformer = pipeline.transformer
355355

356-
if not getattr(transformer, "_cache_dit_ltx2_patched", False):
357-
def _wrap_block_forward(orig):
358-
@functools.wraps(orig)
359-
def _wrapped_forward(self, hidden_states, encoder_hidden_states=None, *args, **kwargs):
360-
audio_hidden_states = kwargs.pop("audio_hidden_states", None)
361-
362-
text_encoder_hidden_states = encoder_hidden_states
363-
if "encoder_hidden_states" in kwargs:
364-
text_encoder_hidden_states = kwargs.pop("encoder_hidden_states")
365-
audio_encoder_hidden_states = kwargs.pop("audio_encoder_hidden_states", None)
366-
temb = kwargs.pop("temb", None)
367-
temb_audio = kwargs.pop("temb_audio", None)
368-
temb_ca_scale_shift = kwargs.pop("temb_ca_scale_shift", None)
369-
temb_ca_audio_scale_shift = kwargs.pop("temb_ca_audio_scale_shift", None)
370-
temb_ca_gate = kwargs.pop("temb_ca_gate", None)
371-
temb_ca_audio_gate = kwargs.pop("temb_ca_audio_gate", None)
372-
373-
return orig(
374-
hidden_states,
375-
audio_hidden_states,
376-
text_encoder_hidden_states,
377-
audio_encoder_hidden_states,
378-
temb,
379-
temb_audio,
380-
temb_ca_scale_shift,
381-
temb_ca_audio_scale_shift,
382-
temb_ca_gate,
383-
temb_ca_audio_gate,
384-
**kwargs,
385-
)
386-
387-
return _wrapped_forward
388-
389-
for block in transformer.transformer_blocks:
390-
block.forward = _wrap_block_forward(block.forward).__get__(block, block.__class__)
391-
392-
transformer._cache_dit_ltx2_patched = True
393-
394356
db_cache_config = _build_db_cache_config(cache_config)
395357

396358
calibrator_config = None

vllm_omni/diffusion/models/ltx2/ltx2_transformer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,22 +1359,22 @@ def forward(
13591359
)
13601360
else:
13611361
hidden_states, audio_hidden_states = block(
1362-
hidden_states=hidden_states,
1363-
audio_hidden_states=audio_hidden_states,
1364-
encoder_hidden_states=encoder_hidden_states,
1365-
audio_encoder_hidden_states=audio_encoder_hidden_states,
1366-
temb=temb,
1367-
temb_audio=temb_audio,
1368-
temb_ca_scale_shift=video_cross_attn_scale_shift,
1369-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1370-
temb_ca_gate=video_cross_attn_a2v_gate,
1371-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1372-
video_rotary_emb=video_rotary_emb,
1373-
audio_rotary_emb=audio_rotary_emb,
1374-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1375-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1376-
encoder_attention_mask=encoder_attention_mask,
1377-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1362+
hidden_states,
1363+
audio_hidden_states,
1364+
encoder_hidden_states,
1365+
audio_encoder_hidden_states,
1366+
temb,
1367+
temb_audio,
1368+
video_cross_attn_scale_shift,
1369+
audio_cross_attn_scale_shift,
1370+
video_cross_attn_a2v_gate,
1371+
audio_cross_attn_v2a_gate,
1372+
video_rotary_emb,
1373+
audio_rotary_emb,
1374+
video_cross_attn_rotary_emb,
1375+
audio_cross_attn_rotary_emb,
1376+
encoder_attention_mask,
1377+
audio_encoder_attention_mask,
13781378
)
13791379

13801380
# 6. Output layers (including unpatchification)

0 commit comments

Comments
 (0)