
Commit b04f957

Use moe_context in pipeline and by default and add tests
1 parent ef8e0b7 commit b04f957

File tree

9 files changed: +318 -116 lines changed

src/llmcompressor/args/dataset_arguments.py

Lines changed: 25 additions & 10 deletions
@@ -7,6 +7,7 @@
 HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets.
 """

+import warnings
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
-    calibrate_moe_context: bool = field(
-        default=False,
-        metadata={
-            "help": "If during calibration, the MoE context should be enabled "
-            "for the given model. This usually involves updating all MoE modules "
-            "in the model for the duration of calibration. See moe_context under "
-            "modeling/prepare.py for a list of supported MoEs and their updated "
-            "module definitions"
-        },
-    )
     shuffle_calibration_samples: Optional[bool] = field(
         default=True,
         metadata={
@@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments):
             ),
         },
     )
+    calibrate_moe_context: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "DEPRECATED: This parameter is deprecated and will be "
+                "removed in a future version. "
+                "MoE calibration context is now handled automatically by the pipeline. "
+                "This parameter is ignored and will not affect the calibration process."
+            ),
+        },
+    )
     # --- pipeline arguments --- #
     pipeline: Optional[str] = field(
         default="independent",
@@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments):

     def is_dataset_provided(self) -> bool:
         return self.dataset is not None or self.dataset_path is not None
+
+    def __post_init__(self):
+        """Post-initialization hook to issue deprecation warnings."""
+        if self.calibrate_moe_context is not None:
+            warnings.warn(
+                "The 'calibrate_moe_context' parameter is deprecated "
+                "and will be removed in a future version. "
+                "MoE calibration context is now handled automatically by the pipeline. "
+                "This parameter is ignored and will not affect "
+                "the calibration process.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
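Note: a minimal sketch of how the new deprecation path behaves for callers, assuming DatasetArguments is imported directly from this module; the dataset name is a placeholder.

import warnings

from llmcompressor.args.dataset_arguments import DatasetArguments

# Passing the deprecated flag still constructs the arguments object, but
# __post_init__ emits a DeprecationWarning and the value is otherwise ignored.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = DatasetArguments(dataset="open_platypus", calibrate_moe_context=True)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert args.is_dataset_provided()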

src/llmcompressor/modeling/moe_context.py

Lines changed: 161 additions & 15 deletions
@@ -8,13 +8,72 @@

 import contextlib
 from abc import ABC, abstractmethod
-from typing import Dict, TypeVar, Union
+from dataclasses import dataclass
+from enum import Enum
+from typing import Callable, Dict, Optional, TypeVar, Union

+import tqdm
+from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel

+from llmcompressor.utils.helpers import patch_attr
+
 T = TypeVar("T", bound="MoECalibrationContext")


+class MoECalibrationType(Enum):
+    """Enumeration of supported MoE calibration types."""
+
+    PERMANENT = "permanent"
+    CONTEXTUAL = "contextual"
+
+
+@dataclass
+class MoEModelConfig:
+    """
+    Configuration for MoE model calibration.
+
+    This dataclass defines the parameters needed to configure MoE calibration
+    for a specific model architecture. It follows the same pattern used by
+    other model configuration systems in the project (e.g., SmoothQuant, AWQ).
+
+    Attributes:
+        calibration_type: Type of calibration - MoECalibrationType.PERMANENT or
+            MoECalibrationType.CONTEXTUAL
+        target_class_name: The class name of the MoE module to replace
+        replace_function: Function that creates the replacement module
+        target_attribute: For contextual calibration, the attribute to replace
+        description: Optional description of the model configuration
+    """
+
+    calibration_type: MoECalibrationType
+    target_class_name: str
+    replace_function: Callable
+    target_attribute: Optional[str] = None
+    description: Optional[str] = None
+
+    def __post_init__(self):
+        """Validate configuration after initialization."""
+        if (
+            self.calibration_type == MoECalibrationType.CONTEXTUAL
+            and self.target_attribute is None
+        ):
+            raise ValueError("target_attribute is required for contextual calibration")
+
+        if (
+            self.calibration_type == MoECalibrationType.PERMANENT
+            and self.target_attribute is not None
+        ):
+            raise ValueError(
+                "target_attribute should not be set for permanent calibration"
+            )
+
+
+# Registry of MoE model configurations
+# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY
+MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {}
+
+
 class MoECalibrationContext(ABC):
     """
     Abstract base class for MoE calibration.
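Note: a short sketch of the __post_init__ validation above; the class name and replace function are hypothetical placeholders, not part of this commit.

from llmcompressor.modeling.moe_context import MoECalibrationType, MoEModelConfig


def replace_dummy_moe(config, module, calibrate_all_experts):
    # Hypothetical replacement factory; a real one returns a calibration-aware module
    return module


# Contextual calibration must name the attribute that gets patched in place.
MoEModelConfig(
    calibration_type=MoECalibrationType.CONTEXTUAL,
    target_class_name="DummyMoeDecoderLayer",
    target_attribute="mlp",
    replace_function=replace_dummy_moe,
)

# Omitting target_attribute for a contextual config fails at construction time.
try:
    MoEModelConfig(
        calibration_type=MoECalibrationType.CONTEXTUAL,
        target_class_name="DummyMoeDecoderLayer",
        replace_function=replace_dummy_moe,
    )
except ValueError as err:
    print(err)  # target_attribute is required for contextual calibration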
@@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N
         """Apply MoE calibration modifications using context managers."""
         if self._stack is None:
             self._stack = contextlib.ExitStack()
+            self._stack.__enter__()

         self.update_function(model, self._stack, calibrate_all_experts)

     def restore(self, model: PreTrainedModel) -> None:
         """Restore the model by exiting the context stack."""
         if self._stack is not None:
-            self._stack.close()
+            self._stack.__exit__(None, None, None)
             self._stack = None


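Note: apply() now enters the ExitStack explicitly and restore() exits it. A self-contained sketch of the standard-library ExitStack pattern the pair relies on (plain contextlib, no llmcompressor API):

import contextlib


@contextlib.contextmanager
def patched(name):
    print(f"patch {name}")
    try:
        yield
    finally:
        print(f"unpatch {name}")


# Mirror of the apply()/restore() lifecycle: enter the stack once, push any
# number of patches onto it, then unwind them all with a single __exit__ call.
stack = contextlib.ExitStack()
stack.__enter__()
for layer in ("layers.0.mlp", "layers.1.mlp"):
    stack.enter_context(patched(layer))
stack.__exit__(None, None, None)  # exits the patches in reverse order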
@@ -134,22 +194,108 @@ def list_supported_models() -> list:
     return list(_MOE_CONTEXTS.keys())


-# Convenience function for backward compatibility
-def create_context_manager_context(model_class_name: str, update_function):
+# Generic factory functions for creating MoE updaters
+def create_permanent_moe_updater(target_class_name: str, replace_function: Callable):
     """
-    Create a context manager-based MoE calibration.
-    :param model_class_name: The class name of the model
-    :param update_function: Function that applies the MoE modifications
-    :return: A ContextualMoECalibration instance
+    Create a permanent MoE updater function for the given target class.
+
+    Args:
+        target_class_name: The class name to look for in the model
+        replace_function: Function that creates the replacement module
+
+    Returns:
+        A function that can be used with PermanentMoECalibration
     """
-    return ContextualMoECalibration(model_class_name, update_function)

+    def update_function(model: PreTrainedModel, calibrate_all_experts: bool):
+        """Update MoE modules for calibration."""
+        for name, module in tqdm.tqdm(list(model.named_modules())):
+            if module.__class__.__name__ == target_class_name:
+                new_module = replace_function(
+                    config=model.config,
+                    module=module,
+                    calibrate_all_experts=calibrate_all_experts,
+                )
+                replace_module(model, name, new_module)
+
+    return update_function

-def create_permanent_context(model_class_name: str, replacement_function):
+
+def create_contextual_moe_updater(
+    target_class_name: str, target_attr: str, replace_function: Callable
+):
     """
-    Create a permanent MoE calibration.
-    :param model_class_name: The class name of the model
-    :param replacement_function: Function that permanently replaces MoE modules
-    :return: A PermanentMoECalibration instance
+    Create a contextual MoE updater function for the given target class and attribute.
+
+    Args:
+        target_class_name: The class name to look for in the model
+        target_attr: The attribute name to replace within the target class
+        replace_function: Function that creates the replacement module
+
+    Returns:
+        A function that can be used with ContextualMoECalibration
+    """
+
+    def update_function(
+        model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool
+    ):
+        """Update MoE modules for calibration using context managers."""
+        for module in model.modules():
+            if module.__class__.__name__ == target_class_name:
+                stack.enter_context(
+                    patch_attr(
+                        module,
+                        target_attr,
+                        replace_function(
+                            config=model.config,
+                            module=getattr(module, target_attr),
+                            calibrate_all_experts=calibrate_all_experts,
+                        ),
+                    )
+                )
+
+    return update_function
+
+
+def register_moe_model(model_class_name: str, config: MoEModelConfig):
     """
-    return PermanentMoECalibration(model_class_name, replacement_function)
+    Register a MoE model with its configuration.
+
+    Args:
+        model_class_name: The model class name
+        config: MoEModelConfig dataclass instance with calibration parameters
+    """
+    if config.calibration_type == MoECalibrationType.PERMANENT:
+        updater = create_permanent_moe_updater(
+            config.target_class_name, config.replace_function
+        )
+        context = PermanentMoECalibration(config.target_class_name, updater)
+    elif config.calibration_type == MoECalibrationType.CONTEXTUAL:
+        updater = create_contextual_moe_updater(
+            config.target_class_name, config.target_attribute, config.replace_function
+        )
+        context = ContextualMoECalibration(model_class_name, updater)
+    else:
+        raise ValueError(f"Unknown MoE type: {config.calibration_type}")
+
+    register_moe_context(model_class_name, context)
+
+
+def register_moe_model_from_dict(model_class_name: str, config_dict: dict):
+    """
+    Register a MoE model from a dictionary configuration (backward compatibility).
+
+    Args:
+        model_class_name: The model class name
+        config_dict: Dictionary with calibration parameters
+    """
+    # Convert string calibration_type to enum
+    if "calibration_type" in config_dict and isinstance(
+        config_dict["calibration_type"], str
+    ):
+        config_dict["calibration_type"] = MoECalibrationType(
+            config_dict["calibration_type"]
+        )
+
+    config = MoEModelConfig(**config_dict)
+    register_moe_model(model_class_name, config)
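Note: a hedged sketch of registering a new architecture through the dict-based helper; the model and layer class names are hypothetical placeholders, and passing a MoEModelConfig to register_moe_model directly is equivalent.

from llmcompressor.modeling.moe_context import register_moe_model_from_dict

# The string "contextual" is converted to MoECalibrationType.CONTEXTUAL, an
# updater is built around the (hypothetical) replace function, and the
# resulting context is added to the MoE registry under the model class name.
register_moe_model_from_dict(
    "DummyMoeForCausalLM",
    {
        "calibration_type": "contextual",
        "target_class_name": "DummyMoeDecoderLayer",
        "target_attribute": "mlp",
        "replace_function": lambda config, module, calibrate_all_experts: module,
    },
)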
src/llmcompressor/modeling/prepare.py

Lines changed: 66 additions & 30 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import warnings

 import tqdm
@@ -6,10 +7,15 @@

 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.moe_context import (
+    MoECalibrationType,
+    MoEModelConfig,
+    get_moe_context,
+    register_moe_model,
+)
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
-from llmcompressor.utils.helpers import patch_attr

-__all__ = ["replace_modules_for_calibration"]
+__all__ = ["moe_calibration_context"]

 # ---------------------- module replacements; permanent -------------------------
 replacements = {
@@ -24,8 +30,8 @@ def replace_modules_for_calibration(
 ) -> PreTrainedModel:
     # This function is deprecated. Use moe_calibration_context instead.
     warnings.warn(
-        "replace_modules_for_calibration is deprecated. \
-            Use moe_calibration_context instead.",
+        "replace_modules_for_calibration is deprecated. "
+        "Use moe_calibration_context instead.",
         DeprecationWarning,
         stacklevel=2,
     )
@@ -45,37 +51,67 @@ replace_modules_for_calibration(

 # ------------------- module replacements; during calibration --------------------

-
-def update_qwen3_moe(model, stack, calibrate_all_experts):
-    for module in model.modules():
-        cls_name = module.__class__.__name__
-        if cls_name == "Qwen3MoeDecoderLayer":
-            # Optionally update the model.config to pass in other arguments
-            stack.enter_context(
-                patch_attr(
-                    module,
-                    "mlp",
-                    replace_Qwen3MoE(
-                        config=model.config,
-                        module=module.mlp,
-                        calibrate_all_experts=calibrate_all_experts,
-                    ),
-                )
-            )
+# MoE model configurations - centralized registry
+# Adding a new MoE model is now as simple as adding an entry here!
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ
+MOE_EXPERTS_REPLACEMENTS = {
+    "Qwen3MoeForCausalLM": MoEModelConfig(
+        calibration_type=MoECalibrationType.CONTEXTUAL,
+        target_class_name="Qwen3MoeDecoderLayer",
+        target_attribute="mlp",
+        replace_function=replace_Qwen3MoE,
+        description="Qwen3 MoE model with contextual calibration for MLP layers",
+    ),
+    "DeepseekV3ForCausalLM": MoEModelConfig(
+        calibration_type=MoECalibrationType.PERMANENT,
+        target_class_name="DeepseekV3MoE",
+        replace_function=replace_deepseekv3,
+        description="DeepSeek V3 MoE model with permanent calibration",
+    ),
+    "Llama4ForConditionalGeneration": MoEModelConfig(
+        calibration_type=MoECalibrationType.PERMANENT,
+        target_class_name="Llama4TextMoe",
+        replace_function=replace_llama4,
+        description=(
+            "Llama4 MoE model with permanent calibration for vLLM compatibility"
+        ),
+    ),
+}


-moe_context = {
-    "Qwen3MoeForCausalLM": update_qwen3_moe,
-}
+# Register all MoE models automatically
+for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items():
+    register_moe_model(model_class_name, config)


+@contextlib.contextmanager
 def moe_calibration_context(
     model: PreTrainedModel,
-    stack,
-    calibrate_all_experts: bool = False,
+    calibrate_all_experts: bool = True,
 ):
-    # Temporarily updates the MoE modules within the context
-    # Once the context exists, parameter updates persist
+    """
+    Context manager for MoE calibration that temporarily updates MoE modules.
+
+    Args:
+        model: The model to apply MoE calibration to
+        calibrate_all_experts: Whether to calibrate all experts or only routed ones
+
+    Yields:
+        The model with MoE calibration applied
+    """
     cls_name = model.__class__.__name__
-    if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
+    moe_context = get_moe_context(cls_name)
+
+    if moe_context is None:
+        # No MoE context registered for this model, yield unchanged
+        yield model
+        return
+
+    # Apply MoE calibration
+    moe_context.apply(model, calibrate_all_experts)
+
+    try:
+        yield model
+    finally:
+        # Restore original state
+        moe_context.restore(model)
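Note: a hedged usage sketch of the new context manager, assuming this module is llmcompressor.modeling.prepare; the checkpoint id is a placeholder for any supported Qwen3 MoE model and loading it requires the corresponding weights.

from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import moe_calibration_context

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # placeholder id

# Within the context, each Qwen3MoeDecoderLayer.mlp is patched for calibration
# (with calibrate_all_experts=True, all experts are calibrated rather than only
# the routed ones); on exit, the original modules are restored.
with moe_calibration_context(model, calibrate_all_experts=True) as calib_model:
    ...  # run forward passes over calibration samples here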

0 commit comments
