From c5dbef607682cba86b5f71c1283f56666cd6d28a Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Wed, 11 Feb 2026 21:53:08 +0530 Subject: [PATCH 1/7] feat(diffusion): add cache-dit support for Stable Audio Open 1.0 Signed-off-by: akshatvishu --- .../diffusion/cache/cache_dit_backend.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index f014d073e8..c74638a42c 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -346,6 +346,74 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Stable Audio Open pipeline. + + Args: + pipeline: The StableAudioPipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + db_cache_config = _build_db_cache_config(cache_config) + + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + logger.info( + f"Enabling cache-dit on Stable Audio transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # StableAudioDiTBlock.forward has two required positional args. + # Pattern_2 intercepts both, treating the 1st as residual and 2nd as context. + cache_dit.enable_cache( + BlockAdapter( + transformer=pipeline.transformer, + blocks=pipeline.transformer.transformer_blocks, + forward_pattern=ForwardPattern.Pattern_2, + params_modifiers=[ + ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + ], + ), + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The StableAudioPipeline instance. + num_inference_steps: New number of inference steps. + verbose: Whether to log refresh operations. + """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for StableDiffusion3Pipeline. @@ -861,6 +929,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "FluxPipeline": enable_cache_for_flux, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, + "StableAudioPipeline": enable_cache_for_stable_audio_open, "StableDiffusion3Pipeline": enable_cache_for_sd3, "BagelPipeline": enable_cache_for_bagel, } From 5212ac66d226d5086b89b68174668f763575e51a Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Wed, 11 Feb 2026 23:47:29 +0530 Subject: [PATCH 2/7] fix(cache-dit): add step guard to prevent sao scm crash during warmup Signed-off-by: akshatvishu --- vllm_omni/diffusion/cache/cache_dit_backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index c74638a42c..acc9153252 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -395,7 +395,11 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool num_inference_steps: New number of inference steps. verbose: Whether to log refresh operations. """ - if cache_config.scm_steps_mask_policy is None: + scm_policy = cache_config.get("scm_steps_mask_policy") + is_supported_scm_step = num_inference_steps >= 8 or num_inference_steps in (4, 6) + use_scm = scm_policy is not None and is_supported_scm_step + + if not use_scm: cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) else: cache_dit.refresh_context( From adadbc146aede3e51cbf100f7139ce609ebc0a6e Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Thu, 12 Feb 2026 00:37:58 +0530 Subject: [PATCH 3/7] Revert "fix(cache-dit): add step guard to prevent sao scm crash during warmup" This reverts commit e4c5a1f55bc098aafda226294c88a1532d021292. Signed-off-by: akshatvishu --- vllm_omni/diffusion/cache/cache_dit_backend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index acc9153252..c74638a42c 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -395,11 +395,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool num_inference_steps: New number of inference steps. verbose: Whether to log refresh operations. """ - scm_policy = cache_config.get("scm_steps_mask_policy") - is_supported_scm_step = num_inference_steps >= 8 or num_inference_steps in (4, 6) - use_scm = scm_policy is not None and is_supported_scm_step - - if not use_scm: + if cache_config.scm_steps_mask_policy is None: cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) else: cache_dit.refresh_context( From 4569fd7f1bf4b21af12ab6bea18fe0e2acfc80af Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Thu, 12 Feb 2026 01:47:55 +0530 Subject: [PATCH 4/7] docs: update Cache-DiT entry Signed-off-by: akshatvishu --- docs/user_guide/diffusion_acceleration.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 1a2e0a7d23..8ce689ed4f 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -72,6 +72,12 @@ The following table shows which models are currently supported by each accelerat |-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | +### AudioGen + +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel |Tensor-Parallel | +|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:|:----------------:| +| **Stable-Audio-Open** | `stabilityai/stable-audio-open-1.0` | | ✅ | ✅ | ❓ | ❓ | ❓ | ❌ | + ### Quantization | Model | Model Identifier | FP8 | From 63901916b0b2a3ccd6f2d50ee2cac1c351b2bec0 Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Thu, 12 Feb 2026 21:45:01 +0530 Subject: [PATCH 5/7] correctly initialize Stable Audio cache context Signed-off-by: akshatvishu --- vllm_omni/diffusion/cache/cache_dit_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index c74638a42c..f85b45cd6e 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -385,6 +385,7 @@ def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Call ) ], ), + cache_config=db_cache_config, ) def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: From 2667215a2d6df3e7a65fa36dfb39ad74c34142db Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Thu, 12 Feb 2026 23:09:37 +0530 Subject: [PATCH 6/7] Fix StableAudio CacheDiT forward_pattern to Pattern_3 to match vipshop/cache-dit Signed-off-by: akshatvishu --- vllm_omni/diffusion/cache/cache_dit_backend.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index f85b45cd6e..ee9c0d34a9 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -371,13 +371,17 @@ def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Call f"W={db_cache_config.max_warmup_steps}, " ) - # StableAudioDiTBlock.forward has two required positional args. - # Pattern_2 intercepts both, treating the 1st as residual and 2nd as context. + # StableAudio is officially registered in CacheDiT as Pattern_3: + # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562 + # + # Pattern_3 is required because StableAudioDiT uses cross-attention + # with static encoder_hidden_states that do not change inside the + # transformer block loop. cache_dit.enable_cache( BlockAdapter( transformer=pipeline.transformer, blocks=pipeline.transformer.transformer_blocks, - forward_pattern=ForwardPattern.Pattern_2, + forward_pattern=ForwardPattern.Pattern_3, params_modifiers=[ ParamsModifier( cache_config=db_cache_config, From 3dfc0b2f2abbfb03ede63f2caee98cf35da02d25 Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Fri, 13 Feb 2026 03:32:50 +0530 Subject: [PATCH 7/7] feat: add _repeated_blocks to Stable Audio DiT for Cache-DiT acceleration Signed-off-by: akshatvishu --- .../diffusion/models/stable_audio/stable_audio_transformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py index 22d56ac1fd..4a4892673f 100644 --- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module): - Output: [B, out_channels, L] """ + _repeated_blocks = ["StableAudioDiTBlock"] + def __init__( self, od_config: OmniDiffusionConfig | None = None,