@@ -351,6 +351,49 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
351351
352352def enable_cache_for_ltx2 (pipeline : Any , cache_config : Any ) -> Callable [[int ], None ]:
353353 """Enable cache-dit for LTX2 pipelines (audio-video transformer blocks)."""
354+ transformer = pipeline .transformer
355+
356+ if not getattr (transformer , "_cache_dit_ltx2_patched" , False ):
357+ def _wrap_block_forward (orig ):
358+ @functools .wraps (orig )
359+ def _wrapped_forward (self , hidden_states , encoder_hidden_states = None , * args , ** kwargs ):
360+ audio_hidden_states = encoder_hidden_states
361+ if "audio_hidden_states" in kwargs :
362+ if audio_hidden_states is None :
363+ audio_hidden_states = kwargs .pop ("audio_hidden_states" )
364+ else :
365+ kwargs .pop ("audio_hidden_states" )
366+
367+ text_encoder_hidden_states = kwargs .pop ("encoder_hidden_states" , None )
368+ audio_encoder_hidden_states = kwargs .pop ("audio_encoder_hidden_states" , None )
369+ temb = kwargs .pop ("temb" , None )
370+ temb_audio = kwargs .pop ("temb_audio" , None )
371+ temb_ca_scale_shift = kwargs .pop ("temb_ca_scale_shift" , None )
372+ temb_ca_audio_scale_shift = kwargs .pop ("temb_ca_audio_scale_shift" , None )
373+ temb_ca_gate = kwargs .pop ("temb_ca_gate" , None )
374+ temb_ca_audio_gate = kwargs .pop ("temb_ca_audio_gate" , None )
375+
376+ return orig (
377+ hidden_states ,
378+ audio_hidden_states ,
379+ text_encoder_hidden_states ,
380+ audio_encoder_hidden_states ,
381+ temb ,
382+ temb_audio ,
383+ temb_ca_scale_shift ,
384+ temb_ca_audio_scale_shift ,
385+ temb_ca_gate ,
386+ temb_ca_audio_gate ,
387+ ** kwargs ,
388+ )
389+
390+ return _wrapped_forward
391+
392+ for block in transformer .transformer_blocks :
393+ block .forward = _wrap_block_forward (block .forward ).__get__ (block , block .__class__ )
394+
395+ transformer ._cache_dit_ltx2_patched = True
396+
354397 db_cache_config = _build_db_cache_config (cache_config )
355398
356399 calibrator_config = None
@@ -359,7 +402,6 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
359402 calibrator_config = TaylorSeerCalibratorConfig (taylorseer_order = taylorseer_order )
360403 logger .info (f"TaylorSeer enabled with order={ taylorseer_order } " )
361404
362- transformer = pipeline .transformer
363405 blocks = transformer .transformer_blocks
364406
365407 logger .info (
0 commit comments