diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md
index 1a2e0a7d23..8ce689ed4f 100644
--- a/docs/user_guide/diffusion_acceleration.md
+++ b/docs/user_guide/diffusion_acceleration.md
@@ -72,6 +72,12 @@ The following table shows which models are currently supported by each accelerat
 |-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:|
 | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ |
 
+### AudioGen
+
+| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | Tensor-Parallel |
+|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:|:----------------:|
+| **Stable-Audio-Open** | `stabilityai/stable-audio-open-1.0` | ❌ | ✅ | ✅ | ❓ | ❓ | ❌ |
+
 ### Quantization
 
 | Model | Model Identifier | FP8 |
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index f014d073e8..ee9c0d34a9 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -346,6 +346,79 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
     return refresh_cache_context
 
 
+def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+    """Enable cache-dit for Stable Audio Open pipeline.
+
+    Args:
+        pipeline: The StableAudioPipeline instance.
+        cache_config: DiffusionCacheConfig instance with cache configuration.
+
+    Returns:
+        A refresh function that can be called to update cache context with new num_inference_steps.
+    """
+    db_cache_config = _build_db_cache_config(cache_config)
+
+    calibrator_config = None
+    if cache_config.enable_taylorseer:
+        taylorseer_order = cache_config.taylorseer_order
+        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+    logger.info(
+        f"Enabling cache-dit on Stable Audio transformer: "
+        f"Fn={db_cache_config.Fn_compute_blocks}, "
+        f"Bn={db_cache_config.Bn_compute_blocks}, "
+        f"W={db_cache_config.max_warmup_steps}"
+    )
+
+    # StableAudio is officially registered in CacheDiT as Pattern_3:
+    # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
+    #
+    # Pattern_3 is required because StableAudioDiT uses cross-attention
+    # with static encoder_hidden_states that do not change inside the
+    # transformer block loop.
+    cache_dit.enable_cache(
+        BlockAdapter(
+            transformer=pipeline.transformer,
+            blocks=pipeline.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_3,
+            params_modifiers=[
+                ParamsModifier(
+                    cache_config=db_cache_config,
+                    calibrator_config=calibrator_config,
+                )
+            ],
+        ),
+        cache_config=db_cache_config,
+    )
+
+    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+        """Refresh cache context for the transformer with new num_inference_steps.
+
+        Args:
+            pipeline: The StableAudioPipeline instance.
+            num_inference_steps: New number of inference steps.
+            verbose: Whether to log refresh operations.
+ """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for StableDiffusion3Pipeline. @@ -861,6 +934,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "FluxPipeline": enable_cache_for_flux, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, + "StableAudioPipeline": enable_cache_for_stable_audio_open, "StableDiffusion3Pipeline": enable_cache_for_sd3, "BagelPipeline": enable_cache_for_bagel, } diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py index 22d56ac1fd..4a4892673f 100644 --- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module): - Output: [B, out_channels, L] """ + _repeated_blocks = ["StableAudioDiTBlock"] + def __init__( self, od_config: OmniDiffusionConfig | None = None,