-
Notifications
You must be signed in to change notification settings - Fork 472
[feature] Add TeaCache Support to Glm Image #1458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
539465b
88ff0a7
d831a6a
14be7fa
aed91cb
88fd80c
561a175
c10fe29
c9234aa
8074caa
85def43
3e1dca9
d0bf296
4b69c41
6d39eb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,24 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import json | ||
| import types | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from huggingface_hub import hf_hub_download | ||
| from vllm.config import LoadConfig | ||
| from vllm.logger import init_logger | ||
|
|
||
| from vllm_omni.diffusion.cache.teacache.extractors import get_extractor | ||
| from vllm_omni.diffusion.data import OmniDiffusionConfig | ||
| from vllm_omni.diffusion.hooks import HookRegistry, ModelHook | ||
| from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader | ||
| from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline | ||
| from vllm_omni.diffusion.models.glm_image.pipeline_glm_image import GlmImagePipeline | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class DataCollectionHook(ModelHook): | ||
|
|
@@ -32,14 +38,15 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | |
|
|
||
| def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: | ||
| ctx = self.extractor_fn(module, *args, **kwargs) | ||
| modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy() | ||
| # Cast to float32 before .numpy() — bfloat16 is not supported by numpy | ||
| modulated_input_cpu = ctx.modulated_input.detach().cpu().float().numpy() | ||
|
|
||
| outputs = ctx.run_transformer_blocks() | ||
| ctx.hidden_states = outputs[0] | ||
| if len(outputs) > 1 and ctx.encoder_hidden_states is not None: | ||
| ctx.encoder_hidden_states = outputs[1] | ||
|
|
||
| model_output_cpu = ctx.hidden_states.detach().cpu().numpy() | ||
| model_output_cpu = ctx.hidden_states.detach().cpu().float().numpy() | ||
| self.current_trajectory.append((modulated_input_cpu, model_output_cpu)) | ||
|
|
||
| return ctx.postprocess(ctx.hidden_states) | ||
|
|
@@ -82,6 +89,44 @@ def forward_alias(self, *args, **kwargs): | |
| transformer._forward_flow = transformer.forward | ||
|
|
||
|
|
||
| class GlmImageAdapter: | ||
| """Adapter for GLM-Image model.""" | ||
|
|
||
| @staticmethod | ||
| def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16): | ||
| od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) | ||
| od_config.model_class_name = "GlmImagePipeline" | ||
|
|
||
| tf_config_path = hf_hub_download(model_path, "transformer/config.json") | ||
| with open(tf_config_path) as f: | ||
| tf_cfg = json.load(f) | ||
| od_config.tf_model_config.params = tf_cfg | ||
|
|
||
| pipeline = GlmImagePipeline(od_config=od_config) | ||
| loader = DiffusersPipelineLoader(LoadConfig()) | ||
| loader.load_weights(pipeline) | ||
|
|
||
| # GLM-Image bundles a large AR model alongside the DiT. | ||
| # Moving the full pipeline to GPU at once exceeds 80GB. | ||
| # For TeaCache coefficient estimation we only need the DiT transformer, | ||
| # so we move submodels selectively and offload the rest to CPU. | ||
| pipeline.transformer.to(device, dtype=dtype) | ||
| pipeline.vae.to(device, dtype=dtype) | ||
| pipeline.text_encoder.to(device, dtype=dtype) | ||
| # pipeline.model (the AR model) stays on CPU — it's not needed here | ||
|
|
||
| return pipeline | ||
|
|
||
| @staticmethod | ||
| def get_transformer(pipeline: Any) -> tuple[Any, str]: | ||
| return pipeline.transformer, "GlmImageTransformer2DModel" | ||
|
|
||
| @staticmethod | ||
| def install_hook(transformer: Any, hook: Any) -> None: | ||
| registry = HookRegistry.get_or_create(transformer) | ||
| registry.register_hook(hook._HOOK_NAME, hook) | ||
|
|
||
|
|
||
| class DefaultAdapter: | ||
| """Default adapter for standard diffusers pipelines.""" | ||
|
|
||
|
|
@@ -101,6 +146,7 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: | |
|
|
||
| _MODEL_ADAPTERS: dict[str, type] = { | ||
| "Bagel": BagelAdapter, | ||
| "GlmImage": GlmImageAdapter, | ||
| } | ||
|
|
||
| _EPSILON = 1e-6 | ||
|
|
@@ -113,30 +159,6 @@ def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) - | |
| return diff / norm | ||
|
|
||
|
|
||
| def estimate_teacache_coefficients( | ||
| collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4 | ||
| ) -> list[float]: | ||
| """Estimate polynomial coefficients for TeaCache using np.polyfit.""" | ||
| input_diffs, output_diffs = [], [] | ||
|
|
||
| for sample in collected_data: | ||
| for t in range(len(sample) - 1): | ||
| feat_in_curr, feat_out_curr = sample[t] | ||
| feat_in_next, feat_out_next = sample[t + 1] | ||
| input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next)) | ||
| output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next)) | ||
|
|
||
| x = np.array(input_diffs, dtype=np.float64) | ||
| y = np.array(output_diffs, dtype=np.float64) | ||
|
|
||
| print("Data statistics:") | ||
| print(f" Count: {len(x)}") | ||
| print(f" Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}") | ||
| print(f" Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}") | ||
|
|
||
| return np.polyfit(x, y, poly_order).tolist() | ||
|
|
||
|
|
||
| class TeaCacheCoefficientEstimator: | ||
| """Model-agnostic helper class to collect data and estimate TeaCache coefficients.""" | ||
|
|
||
|
|
@@ -147,7 +169,7 @@ def __init__( | |
| device: str = "cuda", | ||
| dtype: torch.dtype = torch.bfloat16, | ||
| ): | ||
| # Add validation here ⬇️ | ||
| # Add validation here | ||
| if model_type not in _MODEL_ADAPTERS: | ||
| available_types = list(_MODEL_ADAPTERS.keys()) | ||
| raise ValueError( | ||
|
|
@@ -160,22 +182,34 @@ def __init__( | |
| self.pipeline = adapter.load_pipeline(model_path, device, dtype) | ||
| self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline) | ||
| self.hook = DataCollectionHook(self.transformer_type) | ||
| self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = [] | ||
| # Store only scalar diffs instead of raw numpy arrays to keep RAM flat | ||
| self.input_diffs: list[float] = [] | ||
| self.output_diffs: list[float] = [] | ||
| adapter.install_hook(self.transformer, self.hook) | ||
|
|
||
| def collect_from_prompt(self, prompt: str, **generate_kwargs): | ||
| self.hook.start_collection() | ||
| from vllm_omni.diffusion.request import OmniDiffusionRequest | ||
| from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams | ||
|
|
||
| req = OmniDiffusionRequest( | ||
| prompt=prompt, | ||
| sampling_params = OmniDiffusionSamplingParams( | ||
| num_inference_steps=generate_kwargs.get("num_inference_steps", 20), | ||
| seed=generate_kwargs.get("seed", 42), | ||
| ) | ||
|
Comment on lines +194 to +197
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
| req = OmniDiffusionRequest( | ||
| prompts=[prompt], | ||
| sampling_params=sampling_params, | ||
| ) | ||
| self.pipeline.forward(req) | ||
| trajectory = self.hook.stop_collection() | ||
|
|
||
| # Compute diffs immediately and discard raw arrays — keeps RAM flat | ||
| if trajectory: | ||
| self.collected_data.append(trajectory) | ||
| for t in range(len(trajectory) - 1): | ||
| feat_in_curr, feat_out_curr = trajectory[t] | ||
| feat_in_next, feat_out_next = trajectory[t + 1] | ||
| self.input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next)) | ||
| self.output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next)) | ||
| del trajectory # explicitly free the raw arrays | ||
|
|
||
| def estimate(self, poly_order: int = 4) -> list[float]: | ||
| """Estimate polynomial coefficients from collected data. | ||
|
|
@@ -189,9 +223,23 @@ def estimate(self, poly_order: int = 4) -> list[float]: | |
| Raises: | ||
| RuntimeError: If no data has been collected | ||
| """ | ||
| if not self.collected_data: | ||
| if not self.input_diffs: | ||
| raise RuntimeError( | ||
| "No data collected for coefficient estimation. " | ||
| "Call collect_from_prompt() at least once before calling estimate()." | ||
| ) | ||
| return estimate_teacache_coefficients(self.collected_data, poly_order) | ||
| x = np.array(self.input_diffs, dtype=np.float64) | ||
| y = np.array(self.output_diffs, dtype=np.float64) | ||
|
|
||
| logger.info( | ||
| "TeaCache data statistics: count=%d, x=[min=%.4e, max=%.4e, mean=%.4e], y=[min=%.4e, max=%.4e, mean=%.4e]", | ||
| len(x), | ||
| x.min(), | ||
| x.max(), | ||
| x.mean(), | ||
| y.min(), | ||
| y.max(), | ||
| y.mean(), | ||
| ) | ||
|
|
||
| return np.polyfit(x, y, poly_order).tolist() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GlmImageAdapter.load_pipelineunconditionally callshf_hub_download(model_path, "transformer/config.json"), which treatsmodel_pathas a Hub repo id; when callers pass a local model directory (a supported pattern elsewhere in this repo), this path resolution fails before the pipeline is built. That makes TeaCache coefficient estimation for GLM unusable in local/offline setups and breaks parity with other loaders that accept filesystem paths.Useful? React with 👍 / 👎.