
Commit 220cd59

princepride authored and hsliuustc0106 committed
[TeaCache]: Add Coefficient Estimation (#940)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 7fb15a1 commit 220cd59

File tree

2 files changed

+342
-1
lines changed


docs/contributing/model/adding_diffusion_model.md

Lines changed: 145 additions & 1 deletion
@@ -204,7 +204,151 @@ Key point for writing the example:

Save or display the generated results so users can validate the integration.

## Step 5: TeaCache Coefficient Estimation (Optional)

If your model supports TeaCache acceleration, you need to estimate the polynomial coefficients for optimal caching performance.

### 5.1 Add Extractor Function

First, implement an extractor function in `vllm_omni/diffusion/cache/teacache/extractors.py`. The extractor computes the modulated input (used for the cache decision) and defines how to run the transformer blocks:

```python
def extract_your_model_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    **kwargs: Any,
) -> CacheContext:
    # 1. Preprocessing
    temb = module.time_embed(timestep)

    # 2. Extract modulated input (for cache decision)
    modulated_input = module.transformer_blocks[0].norm1(hidden_states, temb)

    # 3. Define transformer execution
    def run_transformer_blocks():
        h = hidden_states
        for block in module.transformer_blocks:
            h = block(h, temb=temb)
        return (h,)

    # 4. Define postprocessing
    def postprocess(h):
        return module.proj_out(module.norm_out(h, temb))

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=hidden_states,
        encoder_hidden_states=None,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )
```

Register it in `EXTRACTOR_REGISTRY`:

```python
EXTRACTOR_REGISTRY = {
    ...
    "YourTransformer2DModel": extract_your_model_context,
}
```
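For orientation, the `CacheContext` consumed above can be pictured as a small container type. This is an illustrative sketch only; the real class lives in the TeaCache module and may carry additional fields:

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CacheContext:
    """Sketch of the bundle an extractor hands to the caching machinery."""

    modulated_input: Any                         # tensor driving the cache decision
    hidden_states: Any                           # main transformer input
    encoder_hidden_states: Optional[Any]         # cross-attention stream, if any
    temb: Any                                    # timestep embedding
    run_transformer_blocks: Callable[[], tuple]  # executes the full block stack
    postprocess: Callable[[Any], Any]            # maps hidden states to model output
```

Broadly, the caching layer compares `modulated_input` across timesteps; when the estimated change is small it can skip `run_transformer_blocks` and reuse cached results before calling `postprocess`.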
### 5.2 Add Adapter for Coefficient Estimation

Add an adapter in `vllm_omni/diffusion/cache/teacache/coefficient_estimator.py`:

```python
class YourModelAdapter:
    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        # Load your pipeline
        ...

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        return pipeline.transformer, "YourTransformer2DModel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)


_MODEL_ADAPTERS["YourModel"] = YourModelAdapter
```
### 5.3 Run Coefficient Estimation

Use the provided script to estimate coefficients:

```python
from datasets import load_dataset
from tqdm import tqdm

from vllm_omni.diffusion.cache.teacache.coefficient_estimator import (
    TeaCacheCoefficientEstimator,
)

# Load model
estimator = TeaCacheCoefficientEstimator(
    model_path="/path/to/model",
    model_type="Bagel",  # Your model type
    device="cuda",
)

# Load prompts (the TeaCache paper suggests ~70 prompts)
dataset = load_dataset("nateraw/parti-prompts", split="train")
prompts = dataset["Prompt"][:70]

# Collect data
for prompt in tqdm(prompts):
    estimator.collect_from_prompt(prompt, num_inference_steps=50)

# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
print(f"Coefficients: {coeffs}")
```
### 5.4 Interpreting Coefficient Estimation Results

The estimator prints data statistics and returns the polynomial coefficients. Here is how to interpret them:

**Example Output:**

```
Data statistics:
  Count: 48
  Input Diffs (x): min=1.1089e-02, max=5.2555e-02, mean=2.8435e-02
  Output Diffs (y): min=2.8242e-02, max=2.9792e-01, mean=7.0312e-02
Coefficients: [1333131.29, -168644.23, 7950.51, -163.75, 1.26]
```

**What to Check:**

- **Count**: Number of timestep pairs analyzed. It should be at least 30-50 for a reliable estimate; a low count suggests too few prompts or inference steps.
- **Input/Output Ranges**: Verify that output differences correlate with input differences. If the ranges look unusual, check your prompt diversity.
- **Coefficient Magnitude**: Extremely large values (>1e8) may indicate numerical instability; try collecting more diverse data.

**Troubleshooting:** If the results seem unreliable, try:

- Increasing the number of prompts (100+ recommended)
- Using more diverse prompts from multiple datasets
- Adjusting `num_inference_steps` (try 20, 50, 100)
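The Input/Output Diffs above are relative L1 distances between consecutive timesteps (Eq. 4 of the TeaCache paper). A self-contained sketch of how one `(x, y)` data point is produced, with dummy arrays standing in for the real modulated inputs and model outputs:

```python
import numpy as np

_EPSILON = 1e-6


def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float:
    """Relative L1 distance (Eq. 4 from the TeaCache paper)."""
    diff = np.abs(tensor_current - tensor_next).sum()
    norm = np.abs(tensor_current).sum() + _EPSILON
    return diff / norm


# Modulated inputs at steps t and t+1 (dummy values).
mod_t = np.array([1.0, 2.0, 3.0])
mod_t1 = np.array([1.1, 2.2, 2.7])

# Model outputs at steps t and t+1 (dummy values).
out_t = np.array([0.5, -0.5, 1.0])
out_t1 = np.array([0.7, -0.4, 0.6])

x = calculate_relative_l1(mod_t, mod_t1)  # one input diff  (~0.1)
y = calculate_relative_l1(out_t, out_t1)  # one output diff (~0.35)
```

Each prompt contributes one such pair per adjacent timestep, and the polynomial is fit over all pairs pooled across prompts.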
### 5.5 Add Coefficients to Config

Add the estimated coefficients to `vllm_omni/diffusion/cache/teacache/config.py`:

```python
_MODEL_COEFFICIENTS = {
    ...
    "YourTransformer2DModel": [
        1.04730573e+06,   # a4
        -1.34150749e+05,  # a3
        6.51517806e+03,   # a2
        -1.41209108e+02,  # a1
        1.17241808e+00,   # a0
    ],
}
```
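At inference time, coefficients in this order (highest power first) can be fed directly to `np.polyval` on the relative L1 change of the modulated input, and the estimated output change accumulated against a threshold. The sketch below shows that decision rule in the accumulate-and-compare style of TeaCache; the threshold value and function name here are illustrative, not the tuned defaults:

```python
import numpy as np

# Estimated coefficients [a4, a3, a2, a1, a0], as stored in _MODEL_COEFFICIENTS.
coeffs = [1.04730573e+06, -1.34150749e+05, 6.51517806e+03, -1.41209108e+02, 1.17241808e+00]


def should_reuse_cache(rel_l1: float, accumulated: float, threshold: float = 0.2):
    """Map the input change to an estimated output change, accumulate, compare."""
    est_output_change = abs(float(np.polyval(coeffs, rel_l1)))
    accumulated += est_output_change
    if accumulated < threshold:
        return True, accumulated  # small estimated change: reuse cached result
    return False, 0.0             # recompute the blocks and reset the accumulator


reuse, acc = should_reuse_cache(rel_l1=0.02, accumulated=0.0)
```

The accumulator prevents many individually small changes from compounding into a large unnoticed drift.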
## Step 6: Open a Pull Request

When submitting a pull request to add support for a new model, please include the following information in the PR description:

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from typing import Any

import numpy as np
import torch
from vllm.config import LoadConfig

from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline


class DataCollectionHook(ModelHook):
    """Hook to collect modulated inputs and model outputs for TeaCache coefficient estimation."""

    _HOOK_NAME = "teacache_collector"

    def __init__(self, transformer_type: str):
        super().__init__()
        self.transformer_type = transformer_type
        self.extractor_fn = None
        self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = []

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        self.extractor_fn = get_extractor(self.transformer_type)
        return module

    def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
        ctx = self.extractor_fn(module, *args, **kwargs)
        modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy()

        outputs = ctx.run_transformer_blocks()
        ctx.hidden_states = outputs[0]
        if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
            ctx.encoder_hidden_states = outputs[1]

        model_output_cpu = ctx.hidden_states.detach().cpu().numpy()
        self.current_trajectory.append((modulated_input_cpu, model_output_cpu))

        return ctx.postprocess(ctx.hidden_states)

    def start_collection(self):
        self.current_trajectory = []

    def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
        return list(self.current_trajectory)
class BagelAdapter:
    """Adapter for the Bagel model."""

    @staticmethod
    def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
        od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
        od_config.model_class_name = "BagelPipeline"

        pipeline = BagelPipeline(od_config=od_config)
        loader = DiffusersPipelineLoader(LoadConfig())
        loader.load_weights(pipeline)
        pipeline.to(device)
        return pipeline

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        return pipeline.bagel, "Bagel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        # Expose _forward_flow as `forward` so the hook registry can wrap it,
        # then point _forward_flow at the hooked forward.
        original_forward_flow = transformer._forward_flow

        def forward_alias(self, *args, **kwargs):
            return original_forward_flow(*args, **kwargs)

        transformer.forward = types.MethodType(forward_alias, transformer)
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)
        transformer._forward_flow = transformer.forward
class DefaultAdapter:
    """Default adapter for standard diffusers pipelines."""

    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        return pipeline.transformer, pipeline.transformer.__class__.__name__

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)


_MODEL_ADAPTERS: dict[str, type] = {
    "Bagel": BagelAdapter,
}
_EPSILON = 1e-6


def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float:
    """Calculate relative L1 distance (Eq. 4 from TeaCache paper)."""
    diff = np.abs(tensor_current - tensor_next).sum()
    norm = np.abs(tensor_current).sum() + _EPSILON
    return diff / norm


def estimate_teacache_coefficients(
    collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
    """Estimate polynomial coefficients for TeaCache using np.polyfit."""
    input_diffs, output_diffs = [], []

    for sample in collected_data:
        for t in range(len(sample) - 1):
            feat_in_curr, feat_out_curr = sample[t]
            feat_in_next, feat_out_next = sample[t + 1]
            input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
            output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))

    x = np.array(input_diffs, dtype=np.float64)
    y = np.array(output_diffs, dtype=np.float64)

    print("Data statistics:")
    print(f"  Count: {len(x)}")
    print(f"  Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
    print(f"  Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")

    return np.polyfit(x, y, poly_order).tolist()
class TeaCacheCoefficientEstimator:
    """Model-agnostic helper class to collect data and estimate TeaCache coefficients."""

    def __init__(
        self,
        model_path: str,
        model_type: str = "Bagel",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        # Validate model_type up front so unsupported models fail with a clear error.
        if model_type not in _MODEL_ADAPTERS:
            available_types = list(_MODEL_ADAPTERS.keys())
            raise ValueError(
                f"Unsupported model_type: '{model_type}'. "
                f"Available types: {available_types}. "
                f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
            )

        adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
        self.pipeline = adapter.load_pipeline(model_path, device, dtype)
        self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
        self.hook = DataCollectionHook(self.transformer_type)
        self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
        adapter.install_hook(self.transformer, self.hook)

    def collect_from_prompt(self, prompt: str, **generate_kwargs):
        self.hook.start_collection()
        from vllm_omni.diffusion.request import OmniDiffusionRequest

        req = OmniDiffusionRequest(
            prompt=prompt,
            num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
            seed=generate_kwargs.get("seed", 42),
        )
        self.pipeline.forward(req)
        trajectory = self.hook.stop_collection()
        if trajectory:
            self.collected_data.append(trajectory)

    def estimate(self, poly_order: int = 4) -> list[float]:
        """Estimate polynomial coefficients from collected data.

        Args:
            poly_order: Order of polynomial fit (default: 4)

        Returns:
            List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0]

        Raises:
            RuntimeError: If no data has been collected
        """
        if not self.collected_data:
            raise RuntimeError(
                "No data collected for coefficient estimation. "
                "Call collect_from_prompt() at least once before calling estimate()."
            )
        return estimate_teacache_coefficients(self.collected_data, poly_order)
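Note that `np.polyfit` returns coefficients highest power first, which is exactly the `[a_n, ..., a1, a0]` ordering stored in `_MODEL_COEFFICIENTS` and consumed by `np.polyval`. A self-contained sanity check of that convention on synthetic data (an order-2 polynomial for brevity):

```python
import numpy as np

# Ground-truth relationship y = 2*x^2 + 0.5*x + 0.01 over a typical
# input-diff range (compare the x statistics printed by the estimator).
rng = np.random.default_rng(0)
x = rng.uniform(0.01, 0.06, size=200)
y = 2.0 * x**2 + 0.5 * x + 0.01

# polyfit recovers [a2, a1, a0], highest power first.
coeffs = np.polyfit(x, y, 2)  # ≈ [2.0, 0.5, 0.01]

# polyval consumes the same ordering, as the cache does at inference time.
assert np.isclose(np.polyval(coeffs, 0.03), 2.0 * 0.03**2 + 0.5 * 0.03 + 0.01)
```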
