Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 82 additions & 34 deletions vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import types
from typing import Any

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from vllm.config import LoadConfig
from vllm.logger import init_logger

from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import GlmImagePipeline

logger = init_logger(__name__)


class DataCollectionHook(ModelHook):
Expand All @@ -32,14 +38,15 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:

def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy()
# Cast to float32 before .numpy() — bfloat16 is not supported by numpy
modulated_input_cpu = ctx.modulated_input.detach().cpu().float().numpy()

outputs = ctx.run_transformer_blocks()
ctx.hidden_states = outputs[0]
if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = outputs[1]

model_output_cpu = ctx.hidden_states.detach().cpu().numpy()
model_output_cpu = ctx.hidden_states.detach().cpu().float().numpy()
self.current_trajectory.append((modulated_input_cpu, model_output_cpu))

return ctx.postprocess(ctx.hidden_states)
Expand Down Expand Up @@ -82,6 +89,44 @@ def forward_alias(self, *args, **kwargs):
transformer._forward_flow = transformer.forward


class GlmImageAdapter:
"""Adapter for GLM-Image model."""

@staticmethod
def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
od_config.model_class_name = "GlmImagePipeline"

tf_config_path = hf_hub_download(model_path, "transformer/config.json")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle local checkpoints when reading GLM transformer config

GlmImageAdapter.load_pipeline unconditionally calls hf_hub_download(model_path, "transformer/config.json"), which treats model_path as a Hub repo id; when callers pass a local model directory (a supported pattern elsewhere in this repo), this path resolution fails before the pipeline is built. That makes TeaCache coefficient estimation for GLM unusable in local/offline setups and breaks parity with other loaders that accept filesystem paths.

Useful? React with 👍 / 👎.

with open(tf_config_path) as f:
tf_cfg = json.load(f)
od_config.tf_model_config.params = tf_cfg

pipeline = GlmImagePipeline(od_config=od_config)
loader = DiffusersPipelineLoader(LoadConfig())
loader.load_weights(pipeline)

# GLM-Image bundles a large AR model alongside the DiT.
# Moving the full pipeline to GPU at once exceeds 80GB.
# For TeaCache coefficient estimation we only need the DiT transformer,
# so we move submodels selectively and offload the rest to CPU
pipeline.transformer.to(device, dtype=dtype)
pipeline.vae.to(device, dtype=dtype)
pipeline.text_encoder.to(device, dtype=dtype)
# pipeline.model (the AR model) stays on CPU — it's not needed here

return pipeline

@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.transformer, "GlmImageTransformer2DModel"

@staticmethod
def install_hook(transformer: Any, hook: Any) -> None:
    """Register *hook* on the transformer through its (lazily created) hook registry."""
    HookRegistry.get_or_create(transformer).register_hook(hook._HOOK_NAME, hook)


class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""

Expand All @@ -101,6 +146,7 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:

_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"GlmImage": GlmImageAdapter,
}

_EPSILON = 1e-6
Expand All @@ -113,30 +159,6 @@ def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -
return diff / norm


def estimate_teacache_coefficients(
    collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
    """Estimate polynomial coefficients for TeaCache using np.polyfit.

    Args:
        collected_data: One trajectory per generated sample; each trajectory is
            a list of (modulated_input, model_output) array pairs, one pair per
            diffusion timestep.
        poly_order: Degree of the polynomial fitted to the relative-L1 diffs.

    Returns:
        Polynomial coefficients, highest degree first (np.polyfit order).

    Raises:
        ValueError: If no consecutive-timestep pair exists (empty input, or
            every trajectory has fewer than two timesteps).
    """
    input_diffs, output_diffs = [], []

    for sample in collected_data:
        # Relative change between consecutive timesteps drives the fit.
        for t in range(len(sample) - 1):
            feat_in_curr, feat_out_curr = sample[t]
            feat_in_next, feat_out_next = sample[t + 1]
            input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
            output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))

    # Fail fast with a clear message instead of letting the numpy reductions
    # below raise an opaque "zero-size array" ValueError mid-print.
    if not input_diffs:
        raise ValueError(
            "No timestep pairs found in collected_data; each trajectory must "
            "contain at least two timesteps to estimate coefficients."
        )

    x = np.array(input_diffs, dtype=np.float64)
    y = np.array(output_diffs, dtype=np.float64)

    print("Data statistics:")
    print(f" Count: {len(x)}")
    print(f" Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
    print(f" Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")

    return np.polyfit(x, y, poly_order).tolist()


class TeaCacheCoefficientEstimator:
"""Model-agnostic helper class to collect data and estimate TeaCache coefficients."""

Expand All @@ -147,7 +169,7 @@ def __init__(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
# Add validation here ⬇️
# Add validation here
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
Expand All @@ -160,22 +182,34 @@ def __init__(
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
# Store only scalar diffs instead of raw numpy arrays to keep RAM flat
self.input_diffs: list[float] = []
self.output_diffs: list[float] = []
adapter.install_hook(self.transformer, self.hook)

def collect_from_prompt(self, prompt: str, **generate_kwargs):
self.hook.start_collection()
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams

req = OmniDiffusionRequest(
prompt=prompt,
sampling_params = OmniDiffusionSamplingParams(
num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
seed=generate_kwargs.get("seed", 42),
)
Comment on lines +194 to 197

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Pass GLM prior tokens into coefficient collection requests

collect_from_prompt only populates num_inference_steps and seed, then sends a plain prompt; for GlmImagePipeline, non-warmup requests require prior_token_ids (and optionally prior_token_image_ids) in request extras, otherwise pipeline_glm_image.forward raises a ValueError. As written, the new GLM estimator path cannot collect data from normal prompts, so the advertised support is functionally broken unless users bypass this API.

Useful? React with 👍 / 👎.

req = OmniDiffusionRequest(
prompts=[prompt],
sampling_params=sampling_params,
)
self.pipeline.forward(req)
trajectory = self.hook.stop_collection()

# Compute diffs immediately and discard raw arrays — keeps RAM flat
if trajectory:
self.collected_data.append(trajectory)
for t in range(len(trajectory) - 1):
feat_in_curr, feat_out_curr = trajectory[t]
feat_in_next, feat_out_next = trajectory[t + 1]
self.input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
self.output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))
del trajectory # explicitly free the raw arrays

def estimate(self, poly_order: int = 4) -> list[float]:
"""Estimate polynomial coefficients from collected data.
Expand All @@ -189,9 +223,23 @@ def estimate(self, poly_order: int = 4) -> list[float]:
Raises:
RuntimeError: If no data has been collected
"""
if not self.collected_data:
if not self.input_diffs:
raise RuntimeError(
"No data collected for coefficient estimation. "
"Call collect_from_prompt() at least once before calling estimate()."
)
return estimate_teacache_coefficients(self.collected_data, poly_order)
x = np.array(self.input_diffs, dtype=np.float64)
y = np.array(self.output_diffs, dtype=np.float64)

logger.info(
"TeaCache data statistics: count=%d, x=[min=%.4e, max=%.4e, mean=%.4e], y=[min=%.4e, max=%.4e, mean=%.4e]",
len(x),
x.min(),
x.max(),
x.mean(),
y.min(),
y.max(),
y.mean(),
)

return np.polyfit(x, y, poly_order).tolist()
8 changes: 8 additions & 0 deletions vllm_omni/diffusion/cache/teacache/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
3.20000000e00,
-2.00000000e-02,
],
# Calculated GLM-Image Coefficient
"GlmImageTransformer2DModel": [
-6071.632298241158,
1837.6579251847247,
-172.12278847677337,
7.159036598427308,
-0.07853601464946189,
],
}


Expand Down
121 changes: 121 additions & 0 deletions vllm_omni/diffusion/cache/teacache/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,126 @@ def postprocess(h):
)


def extract_glmimage_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    prior_token_id: torch.Tensor,
    prior_token_drop: torch.Tensor,
    timestep: torch.Tensor,
    target_size: torch.Tensor,
    crop_coords: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    kv_cache=None,
    attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext:
    """Build a CacheContext for GlmImageTransformer2DModel.

    Reproduces the model forward() preamble up to (but not including) the
    transformer blocks: projects the image/glyph/prior streams, computes the
    time-conditioned embedding, and runs the first block's norm1 so that the
    modulated image stream can serve as TeaCache's similarity signal. The
    returned context carries closures to run the remaining blocks and to
    unpatchify/project the final hidden states.

    Args:
        module: The GLM-Image DiT; must expose non-empty `transformer_blocks`.
        hidden_states: Input latent; indexed below as (batch, channels,
            height, width) — the unpatchify in `postprocess` relies on this.
        encoder_hidden_states: Glyph/text conditioning stream.
        prior_token_id: Token ids looked up in `module.prior_token_embedding`.
        prior_token_drop: Boolean index used to zero dropped prior embeddings
            (classifier-free-guidance-style drop — TODO confirm semantics).
        timestep / target_size / crop_coords: Inputs to the time-condition
            embedding; moved to the latent's device first.
        attention_mask, image_rotary_emb, kv_cache, attention_kwargs: Passed
            through to each transformer block; rotary embeddings are computed
            from `module.rope` when not supplied.
        **kwargs: Only `return_dict` (default True) is consulted.

    Returns:
        CacheContext with `modulated_input`, the projected streams, `temb`,
        and the `run_transformer_blocks` / `postprocess` closures.

    Raises:
        ValueError: If `module` has no transformer blocks.
    """
    # Imported lazily so this module does not hard-require diffusers at import
    # time — NOTE(review): presumably mirrors the other extractors; confirm.
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
        raise ValueError("Module must contain transformer_blocks.")

    device = hidden_states.device
    dtype = hidden_states.dtype
    return_dict = kwargs.get("return_dict", True)

    # Original latent geometry, captured before projection so postprocess()
    # can reconstruct the spatial layout.
    orig_batch, orig_c, orig_h, orig_w = hidden_states.shape

    if image_rotary_emb is None:
        # Compute rotary embeddings from the raw latent and pin them to the
        # latent's device (cos/sin pair).
        image_rotary_emb = module.rope(hidden_states)
        image_rotary_emb = (
            image_rotary_emb[0].to(device),
            image_rotary_emb[1].to(device),
        )

    # Project both streams into the transformer's hidden dimension.
    hidden_states = module.image_projector(hidden_states)
    encoder_hidden_states = module.glyph_projector(encoder_hidden_states)

    # clone() so the in-place zeroing below cannot mutate the embedding table.
    prior_embedding = module.prior_token_embedding(prior_token_id).clone()
    prior_embedding[prior_token_drop] *= 0.0
    prior_hidden_states = module.prior_projector(prior_embedding)
    hidden_states = hidden_states + prior_hidden_states

    target_size = target_size.to(device)
    crop_coords = crop_coords.to(device)

    timestep = timestep.to(device)
    temb = module.time_condition_embed(
        timestep,
        target_size,
        crop_coords,
        dtype,
    )

    first_block = module.transformer_blocks[0]

    # norm1 returns a 10-tuple; only the modulated image stream is needed for
    # the TeaCache similarity signal — the gate/shift/scale terms are dropped.
    (
        norm_hidden_states,
        _,
        _,
        _,
        _,
        _,
        _,
        _,
        _,
        _,
    ) = first_block.norm1(hidden_states, encoder_hidden_states, temb)

    # Use image stream after AdaLayerNormZero as similarity signal.
    # This is the dominant diffusion signal across timesteps.
    modulated_input = norm_hidden_states

    def run_transformer_blocks() -> tuple[torch.Tensor, torch.Tensor]:
        # Runs every transformer block sequentially; closes over the projected
        # streams, temb, rotary embeddings and kv_cache computed above.
        h = hidden_states
        e = encoder_hidden_states
        kv_cache_mode = kv_cache.mode if kv_cache is not None else None

        for layer_idx, block in enumerate(module.transformer_blocks):
            # Per-layer cache slot when a kv_cache is provided.
            layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None

            h, e = block(
                hidden_states=h,
                encoder_hidden_states=e,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                attention_mask=attention_mask,
                attention_kwargs=attention_kwargs,
                kv_cache=layer_kv_cache,
                kv_cache_mode=kv_cache_mode,
            )

        return (h, e)

    def postprocess(h: torch.Tensor):
        # Final norm + projection, then unpatchify back to the latent layout.
        h = module.norm_out(h, temb)
        h = module.proj_out(h)

        # Reconstruct spatial layout from patch tokens using original latent shape.
        # Mirrors GlmImageTransformer2DModel.forward() unpatchify logic.
        p = module.patch_size
        post_patch_height = orig_h // p
        post_patch_width = orig_w // p

        h = h.reshape(orig_batch, post_patch_height, post_patch_width, -1, p, p)
        output = h.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
        extra_states=None,
    )


def extract_zimage_context(
module: nn.Module,
x: list[torch.Tensor],
Expand Down Expand Up @@ -576,6 +696,7 @@ def postprocess(h):
"QwenImageTransformer2DModel": extract_qwen_context,
"Bagel": extract_bagel_context,
"ZImageTransformer2DModel": extract_zimage_context,
"GlmImageTransformer2DModel": extract_glmimage_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
Expand Down
13 changes: 13 additions & 0 deletions vllm_omni/diffusion/cache/teacache/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
Expand All @@ -26,6 +27,8 @@
)
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook, StateManager

logger = init_logger(__name__)


class TeaCacheHook(ModelHook):
"""
Expand Down Expand Up @@ -226,8 +229,18 @@ def _should_compute_full_transformer(self, state: TeaCacheState, modulated_inp:

# Decision: below threshold = cache, above = compute
if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh:
logger.debug(
"[%d] TeaCache: SKIPPING (Cache Hit!) - Dist: %.4f",
state.cnt,
state.accumulated_rel_l1_distance,
)
return False # Use cache
else:
logger.debug(
"[%d] TeaCache: COMPUTING - Dist: %.4f",
state.cnt,
state.accumulated_rel_l1_distance,
)
state.accumulated_rel_l1_distance = 0.0 # Reset accumulator
return True # Compute

Expand Down
Loading