Skip to content

Commit 3c664f0

Browse files
committed
fix cache-dit bug
Signed-off-by: David Chen <530634352@qq.com>
1 parent 69423ec commit 3c664f0

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

vllm_omni/diffusion/cache/cache_dit_backend.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,14 +357,11 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
357357
def _wrap_block_forward(orig):
358358
@functools.wraps(orig)
359359
def _wrapped_forward(self, hidden_states, encoder_hidden_states=None, *args, **kwargs):
360-
audio_hidden_states = encoder_hidden_states
361-
if "audio_hidden_states" in kwargs:
362-
if audio_hidden_states is None:
363-
audio_hidden_states = kwargs.pop("audio_hidden_states")
364-
else:
365-
kwargs.pop("audio_hidden_states")
360+
audio_hidden_states = kwargs.pop("audio_hidden_states", None)
366361

367-
text_encoder_hidden_states = kwargs.pop("encoder_hidden_states", None)
362+
text_encoder_hidden_states = encoder_hidden_states
363+
if "encoder_hidden_states" in kwargs:
364+
text_encoder_hidden_states = kwargs.pop("encoder_hidden_states")
368365
audio_encoder_hidden_states = kwargs.pop("audio_encoder_hidden_states", None)
369366
temb = kwargs.pop("temb", None)
370367
temb_audio = kwargs.pop("temb_audio", None)

0 commit comments

Comments
 (0)