Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/user_guide/diffusion_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ The following table shows which models are currently supported by each accelerat
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:|
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ |

### AudioGen

| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel |Tensor-Parallel |
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:|:----------------:|
| **Stable-Audio-Open** | `stabilityai/stable-audio-open-1.0` | | ✅ | ✅ | ❓ | ❓ | ❓ | ❌ |

### Quantization

| Model | Model Identifier | FP8 |
Expand Down
74 changes: 74 additions & 0 deletions vllm_omni/diffusion/cache/cache_dit_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,79 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context


def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for Stable Audio Open pipeline.

Args:
pipeline: The StableAudioPipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.

Returns:
A refresh function that can be called to update cache context with new num_inference_steps.
"""
db_cache_config = _build_db_cache_config(cache_config)

calibrator_config = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

logger.info(
f"Enabling cache-dit on Stable Audio transformer: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)

# StableAudio is officially registered in CacheDiT as Pattern_3:
# https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
#
# Pattern_3 is required because StableAudioDiT uses cross-attention
# with static encoder_hidden_states that do not change inside the
# transformer block loop.
cache_dit.enable_cache(
BlockAdapter(
transformer=pipeline.transformer,
blocks=pipeline.transformer.transformer_blocks,
forward_pattern=ForwardPattern.Pattern_3,
params_modifiers=[
ParamsModifier(
cache_config=db_cache_config,
calibrator_config=calibrator_config,
)
],
),
cache_config=db_cache_config,
)

def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for the transformer with new num_inference_steps.

Args:
pipeline: The StableAudioPipeline instance.
num_inference_steps: New number of inference steps.
verbose: Whether to log refresh operations.
"""
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)

return refresh_cache_context


def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.

Expand Down Expand Up @@ -861,6 +934,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"FluxPipeline": enable_cache_for_flux,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableAudioPipeline": enable_cache_for_stable_audio_open,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
"BagelPipeline": enable_cache_for_bagel,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module):
- Output: [B, out_channels, L]
"""

_repeated_blocks = ["StableAudioDiTBlock"]

def __init__(
self,
od_config: OmniDiffusionConfig | None = None,
Expand Down