|
8 | 8 |
|
9 | 9 | import contextlib
|
10 | 10 | from abc import ABC, abstractmethod
|
11 |
| -from typing import Dict, TypeVar, Union |
| 11 | +from dataclasses import dataclass |
| 12 | +from enum import Enum |
| 13 | +from typing import Callable, Dict, Optional, TypeVar, Union |
12 | 14 |
|
| 15 | +import tqdm |
| 16 | +from compressed_tensors.utils import replace_module |
13 | 17 | from transformers import PreTrainedModel
|
14 | 18 |
|
| 19 | +from llmcompressor.utils.helpers import patch_attr |
| 20 | + |
15 | 21 | T = TypeVar("T", bound="MoECalibrationContext")
|
16 | 22 |
|
17 | 23 |
|
class MoECalibrationType(Enum):
    """Enumeration of supported MoE calibration types."""

    # Experts are replaced in the model via replace_module and stay replaced.
    PERMANENT = "permanent"
    # Experts are patched only for the duration of the calibration context
    # (via an ExitStack of patch_attr patches) and restored afterwards.
    CONTEXTUAL = "contextual"
| 29 | + |
| 30 | + |
@dataclass
class MoEModelConfig:
    """
    Configuration describing how to calibrate a specific MoE architecture.

    Follows the same per-model registry pattern used by other configuration
    systems in the project (e.g., SmoothQuant, AWQ).

    Attributes:
        calibration_type: Type of calibration - MoECalibrationType.PERMANENT or
            MoECalibrationType.CONTEXTUAL
        target_class_name: The class name of the MoE module to replace
        replace_function: Function that creates the replacement module
        target_attribute: For contextual calibration, the attribute to replace
        description: Optional description of the model configuration
    """

    calibration_type: MoECalibrationType
    target_class_name: str
    replace_function: Callable
    target_attribute: Optional[str] = None
    description: Optional[str] = None

    def __post_init__(self):
        """Check that target_attribute agrees with the calibration type."""
        contextual = self.calibration_type == MoECalibrationType.CONTEXTUAL
        has_attribute = self.target_attribute is not None

        # Contextual calibration patches an attribute, so it must be named;
        # permanent calibration replaces whole modules, so naming one is
        # almost certainly a configuration mistake.
        if contextual and not has_attribute:
            raise ValueError("target_attribute is required for contextual calibration")
        if not contextual and has_attribute:
            raise ValueError(
                "target_attribute should not be set for permanent calibration"
            )
| 70 | + |
| 71 | + |
# Registry of MoE model configurations, keyed by model class name.
# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY.
MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {}
| 75 | + |
| 76 | + |
18 | 77 | class MoECalibrationContext(ABC):
|
19 | 78 | """
|
20 | 79 | Abstract base class for MoE calibration.
|
@@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N
|
60 | 119 | """Apply MoE calibration modifications using context managers."""
|
61 | 120 | if self._stack is None:
|
62 | 121 | self._stack = contextlib.ExitStack()
|
| 122 | + self._stack.__enter__() |
63 | 123 |
|
64 | 124 | self.update_function(model, self._stack, calibrate_all_experts)
|
65 | 125 |
|
66 | 126 | def restore(self, model: PreTrainedModel) -> None:
|
67 | 127 | """Restore the model by exiting the context stack."""
|
68 | 128 | if self._stack is not None:
|
69 |
| - self._stack.close() |
| 129 | + self._stack.__exit__(None, None, None) |
70 | 130 | self._stack = None
|
71 | 131 |
|
72 | 132 |
|
@@ -134,22 +194,108 @@ def list_supported_models() -> list:
|
134 | 194 | return list(_MOE_CONTEXTS.keys())
|
135 | 195 |
|
136 | 196 |
|
137 |
| -# Convenience function for backward compatibility |
138 |
| -def create_context_manager_context(model_class_name: str, update_function): |
# Generic factory functions for creating MoE updaters
def create_permanent_moe_updater(target_class_name: str, replace_function: Callable):
    """
    Create a permanent MoE updater function for the given target class.

    Args:
        target_class_name: The class name to look for in the model
        replace_function: Function that creates the replacement module; called
            with ``config=``, ``module=`` and ``calibrate_all_experts=``
            keyword arguments

    Returns:
        A function that can be used with PermanentMoECalibration
    """

    def update_function(model: PreTrainedModel, calibrate_all_experts: bool):
        """Permanently replace every matching MoE module in ``model``."""
        # Snapshot the matching modules first: named_modules() should not be
        # consumed while replace_module mutates the module tree, and filtering
        # up front makes the progress bar count actual replacements rather
        # than every module in the model.
        targets = [
            (name, module)
            for name, module in model.named_modules()
            if module.__class__.__name__ == target_class_name
        ]
        for name, module in tqdm.tqdm(
            targets, desc=f"Replacing {target_class_name} modules"
        ):
            new_module = replace_function(
                config=model.config,
                module=module,
                calibrate_all_experts=calibrate_all_experts,
            )
            replace_module(model, name, new_module)

    return update_function
147 | 222 |
|
148 |
| -def create_permanent_context(model_class_name: str, replacement_function): |
| 223 | + |
def create_contextual_moe_updater(
    target_class_name: str, target_attr: str, replace_function: Callable
):
    """
    Create a contextual MoE updater function for the given target class and attribute.

    Args:
        target_class_name: The class name to look for in the model
        target_attr: The attribute name to replace within the target class
        replace_function: Function that creates the replacement module

    Returns:
        A function that can be used with ContextualMoECalibration
    """

    def update_function(
        model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool
    ):
        """Temporarily swap ``target_attr`` on every matching module."""
        for module in model.modules():
            if type(module).__name__ != target_class_name:
                continue
            replacement = replace_function(
                config=model.config,
                module=getattr(module, target_attr),
                calibrate_all_experts=calibrate_all_experts,
            )
            # patch_attr restores the original attribute automatically when
            # the stack unwinds at the end of calibration.
            stack.enter_context(patch_attr(module, target_attr, replacement))

    return update_function
| 258 | + |
| 259 | + |
def register_moe_model(model_class_name: str, config: MoEModelConfig):
    """
    Register a MoE model with its configuration.

    Builds the appropriate calibration context for ``config``, records the
    config in ``MOE_MODEL_REGISTRY``, and registers the context under
    ``model_class_name``.

    Args:
        model_class_name: The model class name
        config: MoEModelConfig dataclass instance with calibration parameters

    Raises:
        ValueError: if ``config.calibration_type`` is not a recognized type
    """
    if config.calibration_type == MoECalibrationType.PERMANENT:
        updater = create_permanent_moe_updater(
            config.target_class_name, config.replace_function
        )
        # Key the context by the *model* class name, consistent with the
        # contextual branch below and with the legacy create_permanent_context
        # helper (previously this incorrectly passed target_class_name).
        context = PermanentMoECalibration(model_class_name, updater)
    elif config.calibration_type == MoECalibrationType.CONTEXTUAL:
        updater = create_contextual_moe_updater(
            config.target_class_name, config.target_attribute, config.replace_function
        )
        context = ContextualMoECalibration(model_class_name, updater)
    else:
        raise ValueError(f"Unknown MoE type: {config.calibration_type}")

    # Record the raw configuration so registered models can be introspected
    # alongside the contexts built from them.
    MOE_MODEL_REGISTRY[model_class_name] = config
    register_moe_context(model_class_name, context)
| 282 | + |
| 283 | + |
def register_moe_model_from_dict(model_class_name: str, config_dict: dict):
    """
    Register a MoE model from a dictionary configuration (backward compatibility).

    Accepts ``calibration_type`` either as a MoECalibrationType member or as
    its string value (e.g. ``"permanent"``). The caller's dict is never
    modified.

    Args:
        model_class_name: The model class name
        config_dict: Dictionary with calibration parameters

    Raises:
        ValueError: if a string calibration_type is not a valid enum value
        TypeError: if config_dict contains keys MoEModelConfig does not accept
    """
    # Work on a shallow copy so the enum conversion below does not mutate the
    # caller's dictionary (the original wrote the converted value back).
    config_dict = dict(config_dict)

    # Convert string calibration_type to enum
    calibration_type = config_dict.get("calibration_type")
    if isinstance(calibration_type, str):
        config_dict["calibration_type"] = MoECalibrationType(calibration_type)

    config = MoEModelConfig(**config_dict)
    register_moe_model(model_class_name, config)
0 commit comments