-
Notifications
You must be signed in to change notification settings - Fork 453
Feature/calibrate weights dfs fused modules #2394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fc5ec42
c99ce38
97e6ca8
480294d
b3d2de9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| """ | ||
| Central source of truth for vLLM-aligned fused module layouts. | ||
|
|
||
| Defines which submodules form a single "fused" group for TENSOR_GROUP (e.g. NVFP4) | ||
| global-scale sharing in vLLM. All callers that need to detect or iterate over | ||
| fused attention / MLP linears should use the functions in this module. | ||
|
|
||
| Fused attention (vLLM) | ||
| ---------------------- | ||
| - **Traditional:** Three linears `q_proj`, `k_proj`, `v_proj` that share one | ||
| global scale in vLLM. | ||
| - **Fused QKV:** Single linear `qkv_proj` (Phi, etc.). No multi-layer fusion | ||
| needed; already one tensor. | ||
| - **MLA (Multi-head Latent Attention):** Two linears that share one global | ||
| scale: query projection and fused KV projection. Common attribute names: | ||
| - `q_a_proj` (or `q_proj`) for query | ||
| - `kv_a_proj_with_mqa` for key/value (fused KV with MQA). | ||
| Used in DeepSeek V2/V3, Kimi K2, Mistral Large 3, and similar architectures. | ||
|
|
||
| Fused MLP (vLLM) | ||
| ---------------- | ||
| - **Gate/Up:** Two linears `gate_proj`, `up_proj` that share one global scale. | ||
| - **Fused Gate-Up:** Single linear `gate_up_proj` (Phi, etc.). No multi-layer | ||
| fusion needed. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| from torch.nn import Linear, Module | ||
|
|
||
| __all__ = [ | ||
| "get_fused_attention_linears", | ||
| "get_fused_mlp_linears", | ||
| "is_fused_attention_module", | ||
| "is_fused_mlp_module", | ||
| ] | ||
|
|
||
|
|
||
def get_fused_attention_linears(module: Module) -> Optional[List[Linear]]:
    """
    Return the Linear submodules that form one fused attention group for
    vLLM TENSOR_GROUP global scale, or None if this module is not a known
    fused attention container.

    Definitions (vLLM-aligned):
    - **Traditional:** ``q_proj``, ``k_proj``, ``v_proj`` (three linears).
    - **Fused QKV:** single ``qkv_proj`` -> returns None (no cross-layer fusion).
    - **MLA:** ``q_a_proj`` (or ``q_proj``) + ``kv_a_proj_with_mqa`` (two linears).

    :param module: A candidate attention container (e.g. parent of q/k/v or MLA).
    :return: List of Linear modules that should share one global scale, or None.
    """
    # qkv is already a single linear; there is no cross-layer scale to share
    if hasattr(module, "qkv_proj"):
        return None

    has_mla_kv = hasattr(module, "kv_a_proj_with_mqa")

    # Traditional q/k/v triple
    if all(hasattr(module, name) for name in ("q_proj", "k_proj", "v_proj")):
        # A module exposing q/k/v *and* an MLA kv projection is ambiguous;
        # treat it as neither kind of group
        if has_mla_kv:
            return None
        triple = [module.q_proj, module.k_proj, module.v_proj]
        if all(isinstance(layer, Linear) for layer in triple):
            return triple

    # MLA pair: query projection plus fused KV projection
    if has_mla_kv:
        kv_linear = module.kv_a_proj_with_mqa
        query = getattr(module, "q_a_proj", None) or getattr(module, "q_proj", None)
        if isinstance(query, Linear) and isinstance(kv_linear, Linear):
            return [query, kv_linear]

    return None
|
|
||
|
|
||
def get_fused_mlp_linears(module: Module) -> Optional[List[Linear]]:
    """
    Return the Linear submodules that form one fused MLP group for vLLM
    TENSOR_GROUP global scale, or None if not a known fused MLP container.

    Definitions (vLLM-aligned):
    - **Gate/Up:** ``gate_proj``, ``up_proj`` (two linears).
    - **Fused Gate-Up:** single ``gate_up_proj`` -> returns None (no cross-layer fusion).

    :param module: A candidate MLP container (e.g. parent of gate_proj/up_proj).
    :return: List of Linear modules that should share one global scale, or None.
    """
    # gate_up is already a single linear; nothing to fuse across layers
    if hasattr(module, "gate_up_proj"):
        return None

    # Guard against unrelated modules that happen to expose gate_proj/up_proj:
    # only classes whose name contains "mlp" qualify
    if "mlp" not in type(module).__name__.lower():
        return None

    gate = getattr(module, "gate_proj", None)
    up = getattr(module, "up_proj", None)
    if isinstance(gate, Linear) and isinstance(up, Linear):
        return [gate, up]

    return None
|
|
||
|
|
||
def is_fused_attention_module(module: Module) -> bool:
    """Whether ``module`` is a fused attention container (traditional or MLA)."""
    linears = get_fused_attention_linears(module)
    return linears is not None
|
|
||
|
|
||
def is_fused_mlp_module(module: Module) -> bool:
    """Whether ``module`` is a fused MLP container (gate/up)."""
    linears = get_fused_mlp_linears(module)
    return linears is not None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,9 @@ | ||
| from typing import Any, Optional | ||
| import threading | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from typing import Any, Iterable, Iterator, Optional, Tuple | ||
|
|
||
| import torch | ||
| import tqdm | ||
| from compressed_tensors.quantization import ( | ||
| DynamicType, | ||
| QuantizationArgs, | ||
|
|
@@ -11,14 +14,17 @@ | |
| from compressed_tensors.utils import ( | ||
| align_module_device, | ||
| getattr_chain, | ||
| match_named_modules, | ||
| update_offload_parameter, | ||
| ) | ||
| from loguru import logger | ||
| from torch.nn import Module | ||
|
|
||
| from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales | ||
| from llmcompressor.observers import Observer | ||
|
|
||
| __all__ = [ | ||
| "calibrate_weights", | ||
| "initialize_observer", | ||
| "update_weight_zp_scale", | ||
| "calibrate_input_hook", | ||
|
|
@@ -134,6 +140,52 @@ def update_weight_global_scale(module: Module): | |
| ) | ||
|
|
||
|
|
||
| def _post_order_modules(model: Module) -> Iterator[Module]: | ||
| """Yield every module in the tree in DFS post-order.""" | ||
| stack: list[Tuple[Module, bool]] = [(model, False)] | ||
| while stack: | ||
| module, children_done = stack.pop() | ||
| if not children_done: | ||
| stack.append((module, True)) | ||
| for child in reversed(list(module.children())): | ||
| stack.append((child, False)) | ||
| else: | ||
| yield module | ||
|
|
||
|
|
||
def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
    """
    Onload the module weight a single time and compute both the global scale
    (gparam, TENSOR_GROUP strategy only) and scale/zero-point (qparams) from
    that one tensor. Used by the sequential DFS path to avoid a double onload
    for NVFP4.

    :param module: module whose weight should be calibrated
    :param update_zp_scale: when True, also compute scale/zero-point
    """
    weight_args = getattr_chain(module, "quantization_scheme.weights", None)
    if weight_args is None:
        return

    strategy = getattr_chain(module, "quantization_scheme.weights.strategy", None)
    compute_gparam = strategy == QuantizationStrategy.TENSOR_GROUP
    compute_qparams = update_zp_scale
    if not (compute_gparam or compute_qparams):
        return

    status = getattr(module, "quantization_status", None)
    if compute_qparams and status != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    # Single onload: both parameter sets are derived from the same tensor
    with align_module_device(module):
        call_observer(
            module,
            base_name="weight",
            value=module.weight,
            should_calculate_gparam=compute_gparam,
            should_calculate_qparams=compute_qparams,
        )
|
|
||
|
|
||
| def update_weight_zp_scale(module: Module): | ||
| """ | ||
| marks a layer as ready for calibration which activates observers | ||
|
|
@@ -156,6 +208,121 @@ def update_weight_zp_scale(module: Module): | |
| call_observer(module=module, base_name="weight") | ||
|
|
||
|
|
||
def calibrate_weights(
    model: Module,
    *,
    named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
    targets: Iterable[str] = (),
    ignore: Iterable[str] = (),
    update_zp_scale: bool = True,
    desc: Optional[str] = "Calibrating weights",
    show_progress: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    """
    Run weight calibration: per-tensor global scale (gparam), fused global
    scales for attention/MLP groups, and scale/zero-point (qparams).
    Minimizes weight onloads when using offloading (one onload per target in
    the default sequential path).

    Two modes:
    - Sequential (parallel=False): DFS over the model. Pre-order: one onload
      per target via ``_update_weight_calibration_once`` (gparam + qparams).
      Post-order: ``update_fused_layer_weight_global_scales`` (no extra
      onload for targets).
    - Parallel (parallel=True): phase 1 runs gparam + qparams per target
      (order-independent, parallelizable). Phase 2 applies fused global
      scales in post-order.

    Distributed note: this function performs no cross-rank synchronization
    itself. Callers may pass ``named_modules`` as a rank-local subset so each
    rank only calibrates its assigned modules; fused groups (q/k/v, gate/up)
    must stay on the same rank so
    ``update_fused_layer_weight_global_scales`` sees the full group.

    :param model: root module to traverse (e.g. ``state.model``)
    :param named_modules: if provided, only these (name, module) pairs are
        calibrated. If None, uses ``match_named_modules(model, targets, ignore)``
    :param targets: name patterns used when named_modules is None
    :param ignore: ignore patterns used when named_modules is None
    :param update_zp_scale: if True, compute scale/zp for targets. Set False
        for modifiers that compute them in hooks (e.g. GPTQ)
    :param desc: progress bar description; None disables the bar
    :param show_progress: if True and desc is set, show a tqdm bar
    :param parallel: if True, use the two-phase parallel calibration
    :param max_workers: if parallel, number of phase-1 worker threads
    """
    if named_modules is None:
        named_modules = list(match_named_modules(model, targets, ignore))
    else:
        named_modules = list(named_modules)
    # only modules in target_set receive gparam + qparams
    target_set = {m for _, m in named_modules}
    target_list = list(target_set)
    total_targets = len(target_list)

    if show_progress and desc is not None and total_targets > 0:
        pbar = tqdm.tqdm(total=total_targets, desc=desc)
    else:
        pbar = None

    try:
        if parallel:
            # Phase 1: per-module global scale + scale/zp (order-independent)
            pbar_lock = threading.Lock()

            def _phase1_one(module: Module) -> None:
                update_weight_global_scale(module)
                if update_zp_scale:
                    update_weight_zp_scale(module)
                if pbar is not None:
                    # tqdm.update is not thread-safe; serialize it
                    with pbar_lock:
                        pbar.update(1)

            if max_workers is not None and max_workers > 0:
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    list(executor.map(_phase1_one, target_list))
            else:
                for module in target_list:
                    _phase1_one(module)

            # Phase 2: fused global scales (children before parents)
            for module in _post_order_modules(model):
                update_fused_layer_weight_global_scales(module)
        else:
            # Sequential DFS: pre-order does one onload per target for
            # gparam + qparams; post-order applies fused global scales
            seen_pre: set[Module] = set()
            counted: set[Module] = set()
            stack = [(model, False)]
            while stack:
                module, children_done = stack.pop()
                if not children_done:
                    if module in target_set and module not in seen_pre:
                        seen_pre.add(module)
                        _update_weight_calibration_once(module, update_zp_scale)
                    stack.append((module, True))
                    for child in reversed(list(module.children())):
                        stack.append((child, False))
                else:
                    update_fused_layer_weight_global_scales(module)
                    # Fix: advance the bar for every target, not only when
                    # update_zp_scale is True — otherwise the bar never
                    # completes for modifiers that skip zp/scale (e.g. GPTQ)
                    if module in target_set and module not in counted:
                        counted.add(module)
                        if pbar is not None:
                            pbar.update(1)
    finally:
        # close the bar even if calibration raises mid-way
        if pbar is not None:
            pbar.close()
|
|
||
|
|
||
| def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): | ||
| """ | ||
| Calibrate input or output activations by calling the module's attached | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.