Skip to content

Commit c7a9943

Browse files
committed
Refactor MoECalibrationContext
1 parent 0994db6 commit c7a9943

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""
2+
Standardized interface for MoE model calibration.
3+
MoE calibration context is used to apply MoE calibration modifications to the model.
4+
There are two types of MoE calibration contexts:
5+
1. ContextualMoECalibration: uses context managers for temporary modifications
6+
2. PermanentMoECalibration: permanently modifies the model
7+
"""
8+
9+
import contextlib
10+
from abc import ABC, abstractmethod
11+
from typing import Dict, TypeVar, Union
12+
13+
from transformers import PreTrainedModel
14+
15+
T = TypeVar("T", bound="MoECalibrationContext")
16+
17+
18+
class MoECalibrationContext(ABC):
19+
"""
20+
Abstract base class for MoE calibration.
21+
This provides a standardized interface for MoE model calibration.
22+
"""
23+
24+
@abstractmethod
25+
def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
26+
"""
27+
Apply MoE calibration modifications to the model.
28+
:param model: The model to modify
29+
:param calibrate_all_experts: Whether to calibrate all
30+
experts or only routed ones
31+
"""
32+
pass
33+
34+
@abstractmethod
35+
def restore(self, model: PreTrainedModel) -> None:
36+
"""
37+
Restore the model to its original state.
38+
:param model: The model to restore
39+
"""
40+
pass
41+
42+
43+
class ContextualMoECalibration(MoECalibrationContext):
44+
"""
45+
MoE calibration that uses context managers for temporary modifications.
46+
This is suitable for models that need to be restored after calibration.
47+
"""
48+
49+
def __init__(self, model_class_name: str, update_function):
50+
"""
51+
Initialize the context manager-based MoE calibration.
52+
:param model_class_name: The class name of the model this context applies to
53+
:param update_function: Function that applies the MoE modifications
54+
"""
55+
self.model_class_name = model_class_name
56+
self.update_function = update_function
57+
self._stack = None
58+
59+
def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
60+
"""Apply MoE calibration modifications using context managers."""
61+
if self._stack is None:
62+
self._stack = contextlib.ExitStack()
63+
64+
self.update_function(model, self._stack, calibrate_all_experts)
65+
66+
def restore(self, model: PreTrainedModel) -> None:
67+
"""Restore the model by exiting the context stack."""
68+
if self._stack is not None:
69+
self._stack.close()
70+
self._stack = None
71+
72+
73+
class PermanentMoECalibration(MoECalibrationContext):
74+
"""
75+
MoE calibration context that permanently modifies the model.
76+
This is suitable for models that can be loaded in their modified form
77+
(e.g., Llama4 in vLLM).
78+
"""
79+
80+
def __init__(self, model_class_name: str, replacement_function):
81+
"""
82+
Initialize the permanent MoE calibration.
83+
:param model_class_name: The class name of the model this context applies to
84+
:param replacement_function: Function that permanently replaces MoE modules
85+
"""
86+
self.model_class_name = model_class_name
87+
self.replacement_function = replacement_function
88+
self._original_modules = {}
89+
90+
def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
91+
"""Apply permanent MoE calibration modifications."""
92+
# Store original modules for potential restoration
93+
for name, module in model.named_modules():
94+
if module.__class__.__name__ == self.model_class_name:
95+
self._original_modules[name] = module
96+
97+
# Apply the replacement
98+
self.replacement_function(model, calibrate_all_experts)
99+
100+
def restore(self, model: PreTrainedModel) -> None:
101+
"""Restore original modules (if needed)."""
102+
# For permanent MoE calibrations, restoration is typically not needed
103+
# as the model is meant to stay in its modified form
104+
pass
105+
106+
107+
# Registry for MoE calibrations
108+
_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {}
109+
110+
111+
def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None:
112+
"""
113+
Register a MoE calibration context for a model class.
114+
:param model_class_name: The class name of the model
115+
:param context: The MoE calibration context to register
116+
"""
117+
_MOE_CONTEXTS[model_class_name] = context
118+
119+
120+
def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]:
121+
"""
122+
Get the registered MoE calibration context for a model class.
123+
:param model_class_name: The class name of the model
124+
:return: The MoE calibration context or None if not found
125+
"""
126+
return _MOE_CONTEXTS.get(model_class_name)
127+
128+
129+
def list_supported_models() -> list:
130+
"""
131+
List all model classes that have registered MoE calibration contexts.
132+
:return: List of supported model class names
133+
"""
134+
return list(_MOE_CONTEXTS.keys())
135+
136+
137+
# Convenience function for backward compatibility
138+
def create_context_manager_context(model_class_name: str, update_function):
139+
"""
140+
Create a context manager-based MoE calibration.
141+
:param model_class_name: The class name of the model
142+
:param update_function: Function that applies the MoE modifications
143+
:return: A ContextualMoECalibration instance
144+
"""
145+
return ContextualMoECalibration(model_class_name, update_function)
146+
147+
148+
def create_permanent_context(model_class_name: str, replacement_function):
149+
"""
150+
Create a permanent MoE calibration.
151+
:param model_class_name: The class name of the model
152+
:param replacement_function: Function that permanently replaces MoE modules
153+
:return: A PermanentMoECalibration instance
154+
"""
155+
return PermanentMoECalibration(model_class_name, replacement_function)

0 commit comments

Comments
 (0)