|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import types |
| 5 | +from typing import Any |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | +from vllm.config import LoadConfig |
| 10 | + |
| 11 | +from vllm_omni.diffusion.cache.teacache.extractors import get_extractor |
| 12 | +from vllm_omni.diffusion.data import OmniDiffusionConfig |
| 13 | +from vllm_omni.diffusion.hooks import HookRegistry, ModelHook |
| 14 | +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader |
| 15 | +from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline |
| 16 | + |
| 17 | + |
class DataCollectionHook(ModelHook):
    """Hook that records (modulated input, block output) pairs per forward pass.

    The recorded pairs are the raw material for TeaCache polynomial
    coefficient estimation.
    """

    _HOOK_NAME = "teacache_collector"

    def __init__(self, transformer_type: str):
        super().__init__()
        self.transformer_type = transformer_type
        # Resolved lazily in initialize_hook().
        self.extractor_fn = None
        # One (input, output) numpy pair per hooked forward call.
        self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = []

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Pick the model-specific extractor once the hook is attached.
        self.extractor_fn = get_extractor(self.transformer_type)
        return module

    def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
        context = self.extractor_fn(module, *args, **kwargs)
        snapshot_in = context.modulated_input.detach().cpu().numpy()

        block_outputs = context.run_transformer_blocks()
        context.hidden_states = block_outputs[0]
        # Some models also return updated encoder states alongside hidden states.
        if context.encoder_hidden_states is not None and len(block_outputs) > 1:
            context.encoder_hidden_states = block_outputs[1]

        snapshot_out = context.hidden_states.detach().cpu().numpy()
        self.current_trajectory.append((snapshot_in, snapshot_out))

        return context.postprocess(context.hidden_states)

    def start_collection(self):
        """Reset the trajectory buffer before a new generation run."""
        self.current_trajectory = []

    def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Return a copy of the pairs recorded since start_collection()."""
        return list(self.current_trajectory)
| 52 | + |
| 53 | + |
class BagelAdapter:
    """Adapter for the Bagel model: pipeline loading, transformer access, hook install."""

    @staticmethod
    def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
        """Construct a BagelPipeline, load its weights, and move it to ``device``."""
        od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
        od_config.model_class_name = "BagelPipeline"

        pipeline = BagelPipeline(od_config=od_config)
        loader = DiffusersPipelineLoader(LoadConfig())
        loader.load_weights(pipeline)
        pipeline.to(device)
        return pipeline

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return the transformer module and the extractor type key ("Bagel")."""
        return pipeline.bagel, "Bagel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Install the collection hook on Bagel's denoising path.

        Bagel's generation calls ``_forward_flow`` rather than ``forward``,
        while the hook registry targets ``forward``. The sequence below is
        order-dependent: alias ``forward`` to the original ``_forward_flow``,
        register the hook, then rebind ``_forward_flow`` to the (now hooked)
        ``forward`` so the hook fires on each denoising step.
        NOTE(review): assumes HookRegistry.register_hook wraps the module's
        ``forward`` in place — confirm against HookRegistry's implementation.
        """
        # Capture the pre-hook _forward_flow so the alias bypasses the later rebinding.
        original_forward_flow = transformer._forward_flow

        def forward_alias(self, *args, **kwargs):
            return original_forward_flow(*args, **kwargs)

        transformer.forward = types.MethodType(forward_alias, transformer)
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)
        # Route the pipeline's _forward_flow calls through the hooked forward.
        transformer._forward_flow = transformer.forward
| 83 | + |
| 84 | + |
class DefaultAdapter:
    """Fallback adapter for standard diffusers-style pipelines exposing ``.transformer``."""

    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        # No generic loading path exists yet; concrete adapters must provide one.
        raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return the pipeline's transformer and its class name as the type key."""
        transformer = pipeline.transformer
        return transformer, type(transformer).__name__

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Register the hook via the standard registry mechanism."""
        HookRegistry.get_or_create(transformer).register_hook(hook._HOOK_NAME, hook)
| 100 | + |
| 101 | + |
# Registry mapping model_type key -> adapter class; unregistered keys are
# rejected with ValueError by TeaCacheCoefficientEstimator.__init__.
_MODEL_ADAPTERS: dict[str, type] = {
    "Bagel": BagelAdapter,
}
| 105 | + |
# Keeps relative-L1 denominators nonzero for all-zero reference tensors.
_EPSILON = 1e-6


def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float:
    """Calculate relative L1 distance (Eq. 4 from TeaCache paper).

    Args:
        tensor_current: Feature array at the current timestep (the reference).
        tensor_next: Feature array at the next timestep.

    Returns:
        sum(|current - next|) / (sum(|current|) + _EPSILON) as a Python float.
    """
    diff = np.abs(tensor_current - tensor_next).sum()
    norm = np.abs(tensor_current).sum() + _EPSILON
    # np.sum yields a numpy scalar; cast to honor the declared float return type.
    return float(diff / norm)


def estimate_teacache_coefficients(
    collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
    """Estimate polynomial coefficients for TeaCache using np.polyfit.

    Args:
        collected_data: One trajectory per generation run; each trajectory is a
            timestep-ordered list of (modulated_input, model_output) arrays.
        poly_order: Order of the polynomial fit (default: 4).

    Returns:
        Polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0] mapping relative
        input diffs to relative output diffs.

    Raises:
        ValueError: If fewer than ``poly_order + 1`` consecutive-step pairs are
            available — previously this surfaced as a cryptic numpy error
            (empty data) or a silently ill-conditioned fit.
    """
    input_diffs, output_diffs = [], []

    for sample in collected_data:
        for t in range(len(sample) - 1):
            feat_in_curr, feat_out_curr = sample[t]
            feat_in_next, feat_out_next = sample[t + 1]
            input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
            output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))

    if len(input_diffs) < poly_order + 1:
        raise ValueError(
            f"Need at least {poly_order + 1} consecutive-step pairs for a "
            f"degree-{poly_order} fit, got {len(input_diffs)}. Collect longer "
            "or additional trajectories."
        )

    x = np.array(input_diffs, dtype=np.float64)
    y = np.array(output_diffs, dtype=np.float64)

    print("Data statistics:")
    print(f"  Count: {len(x)}")
    print(f"  Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
    print(f"  Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")

    return np.polyfit(x, y, poly_order).tolist()
| 138 | + |
| 139 | + |
class TeaCacheCoefficientEstimator:
    """Model-agnostic helper class to collect data and estimate TeaCache coefficients."""

    def __init__(
        self,
        model_path: str,
        model_type: str = "Bagel",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Load the pipeline for ``model_type`` and install the collection hook.

        Args:
            model_path: Path or identifier of the model to load.
            model_type: Key into _MODEL_ADAPTERS selecting the adapter.
            device: Target device for the pipeline (default: "cuda").
            dtype: Model parameter dtype (default: torch.bfloat16).

        Raises:
            ValueError: If ``model_type`` has no registered adapter.
        """
        if model_type not in _MODEL_ADAPTERS:
            available_types = list(_MODEL_ADAPTERS.keys())
            raise ValueError(
                f"Unsupported model_type: '{model_type}'. "
                f"Available types: {available_types}. "
                f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
            )

        # The check above guarantees the key exists, so index directly
        # (a `.get(..., DefaultAdapter)` fallback here would be dead code).
        adapter = _MODEL_ADAPTERS[model_type]
        self.pipeline = adapter.load_pipeline(model_path, device, dtype)
        self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
        self.hook = DataCollectionHook(self.transformer_type)
        self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
        adapter.install_hook(self.transformer, self.hook)

    def collect_from_prompt(self, prompt: str, **generate_kwargs):
        """Run one generation for ``prompt`` and store the recorded trajectory.

        Only ``num_inference_steps`` (default 20) and ``seed`` (default 42)
        are consumed from ``generate_kwargs``. Empty trajectories (hook never
        fired) are discarded.
        """
        # NOTE(review): function-scope import, presumably to break an import
        # cycle with the request module — confirm before hoisting to file top.
        from vllm_omni.diffusion.request import OmniDiffusionRequest

        self.hook.start_collection()
        req = OmniDiffusionRequest(
            prompt=prompt,
            num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
            seed=generate_kwargs.get("seed", 42),
        )
        self.pipeline.forward(req)
        trajectory = self.hook.stop_collection()
        if trajectory:
            self.collected_data.append(trajectory)

    def estimate(self, poly_order: int = 4) -> list[float]:
        """Estimate polynomial coefficients from collected data.

        Args:
            poly_order: Order of polynomial fit (default: 4)

        Returns:
            List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0]

        Raises:
            RuntimeError: If no data has been collected
        """
        if not self.collected_data:
            raise RuntimeError(
                "No data collected for coefficient estimation. "
                "Call collect_from_prompt() at least once before calling estimate()."
            )
        return estimate_teacache_coefficients(self.collected_data, poly_order)