-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathmoe_context.py
More file actions
201 lines (162 loc) · 6.71 KB
/
moe_context.py
File metadata and controls
201 lines (162 loc) · 6.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
Simplified interface for MoE model calibration.
MoE (Mixture of Experts) models route tokens to different expert networks.
During calibration for quantization/compression, we need to ensure ALL experts
see data, not just the ones selected by the router. This module provides the
infrastructure to temporarily modify MoE modules for proper calibration.
Key components:
- MoECalibrationModule: Abstract base class for calibration modules
- moe_calibration_context: Context manager that applies calibration to a model
"""
import contextlib
from abc import ABC
import torch
import torch.distributed as dist
from compressed_tensors.offload import (
get_cache_init_kwargs,
is_distributed,
)
from compressed_tensors.offload.cache import OffloadCache
from compressed_tensors.offload.module import offload_module
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from tqdm import tqdm
from transformers import PreTrainedModel
# Public API of this module: the names exported by a star-import.
__all__ = [
"MoECalibrationModule",
"moe_calibration_context",
]
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
    """
    Base class for the calibration-time stand-ins of MoE modules.

    While a model is being calibrated, each registered MoE module is swapped
    for a subclass of this class so that every expert receives data and
    accumulates proper quantization statistics.

    A subclass is expected to:
        1. Provide an `__init__()` with the signature
           (self, original, config, calibrate_all_experts=True)
        2. Set `is_permanent` to say whether the module stays in its
           calibration form once calibration is over
        3. Override `restore()` when `is_permanent` is False
    """

    # True: the replacement stays in the model after calibration.
    # False: `restore()` is called to obtain the module to put back.
    is_permanent: bool = False

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Return the module that should take this module's place after calibration.

        The base implementation only covers permanent modules, for which it
        is a no-op that hands back `self`. Non-permanent subclasses must
        override it.

        Args:
            original: The module that was replaced for calibration (unused by
                the base implementation)

        Returns:
            `self` when the module is permanent

        Raises:
            NotImplementedError: If `is_permanent` is False and the subclass
                did not override this method
        """
        if not self.is_permanent:
            raise NotImplementedError(
                f"{self.__class__.__name__} has is_permanent=False but doesn't "
                "implement restore()"
            )
        return self
@contextlib.contextmanager
def moe_calibration_context(
    model: PreTrainedModel,
    calibrate_all_experts: bool = True,
):
    """
    Context manager that applies MoE calibration to a model.

    This scans all modules in the model and replaces any module whose class
    name is registered with `MoECalibrationModule` by its calibration
    equivalent. After the context exits, non-permanent modules are restored
    via their `restore()` method; permanent modules stay in place.

    The model is modified in-place, so the same model object should be used
    within the context.

    Args:
        model: The model to apply MoE calibration to (modified in-place)
        calibrate_all_experts: If True, all experts see all tokens during
            calibration. If False, use normal routing (useful for some
            techniques)

    Yields:
        None. The model passed in is the object to use inside the context.

    Example:
        with moe_calibration_context(model):
            # Run calibration - all experts will see data
            for batch in dataloader:
                model(**batch)
        # Model is now restored (unless permanent)
    """
    # name -> (original module or None if permanent, replacement module)
    replaced = {}

    # Step 1: Collect all MoE modules that need replacement
    logger.debug("Entering MoE calibration context")
    modules_to_replace = []
    for name, module in model.named_modules():
        class_name = module.__class__.__name__
        if _is_registered(class_name, MoECalibrationModule):
            modules_to_replace.append((name, module, class_name))

    # Step 2: Replace modules with progress bar
    if modules_to_replace:
        logger.info(f"Found {len(modules_to_replace)} MoE modules to replace")
        for name, module, class_name in tqdm(
            modules_to_replace, desc="Replacing MoE modules for calibration"
        ):
            replacement = MoECalibrationModule.load_from_registry(
                class_name,
                original=module,
                config=model.config,
                calibrate_all_experts=calibrate_all_experts,
            )
            # Apply the same offloading settings as the original module
            _apply_offloading_to_replacement(module, replacement)

            model.set_submodule(name, replacement)

            # Only store original if we need to restore it later
            if replacement.is_permanent:
                replaced[name] = (None, replacement)
                # Unbind the loop variable so it does not pin the original
                del module
            else:
                replaced[name] = (module, replacement)

            # NOTE(review): barrier kept once per replaced module so ranks
            # stay in lock-step; indentation was lost in the reviewed copy,
            # so confirm this placement against the upstream file.
            if is_distributed():
                dist.barrier()

        # This generator frame stays alive for the whole calibration run, so
        # the collected list would otherwise keep every original module
        # referenced (including those replaced permanently). Originals that
        # must be restored are still held by `replaced`.
        modules_to_replace.clear()

    # Log what was replaced
    if replaced:
        logger.info(f"Replaced {len(replaced)} MoE modules for calibration")
        permanent_count = sum(
            1 for _, repl in replaced.values() if repl.is_permanent
        )
        if permanent_count > 0:
            logger.info(
                f"{permanent_count}/{len(replaced)} modules will remain in "
                "calibration form (permanent)"
            )
        if permanent_count < len(replaced):
            logger.info(
                f"{len(replaced) - permanent_count}/{len(replaced)} modules will "
                "be restored after calibration"
            )

    try:
        yield
    finally:
        # Step 3: Restore non-permanent modules
        for name, (original, replacement) in replaced.items():
            if not replacement.is_permanent:
                restored = replacement.restore(original)
                model.set_submodule(name, restored)
def _is_registered(name: str, subclass: RegistryMixin):
    """
    Tell whether `name` is registered on `subclass`.

    The name is passed through `standardize_lookup_name` before it is looked
    up in `subclass.registered_names()`.
    """
    lookup_key = standardize_lookup_name(name)
    known_names = subclass.registered_names()
    return lookup_key in known_names
def _find_ancestor_with_offload_cache(module):
    """
    Find the first module whose `_parameters` is an `OffloadCache`.

    Despite the name, the search runs downwards: `module` itself is checked
    first, then each child's subtree in `children()` order (depth-first).

    Returns:
        The first matching module, or None if no module in the tree matches.
    """
    if isinstance(module._parameters, OffloadCache):
        return module

    # Lazily search each child's subtree and stop at the first hit
    subtree_hits = (
        _find_ancestor_with_offload_cache(child) for child in module.children()
    )
    return next((hit for hit in subtree_hits if hit is not None), None)
def _apply_offloading_to_replacement(
    original: torch.nn.Module, replacement: torch.nn.Module
):
    """
    Mirror the offloading configuration of `original` onto `replacement`.

    If `original` or any module beneath it keeps its parameters in an
    `OffloadCache`, the init kwargs of that cache are reused to offload every
    submodule of `replacement` that directly owns parameters and is not
    already offloaded. If no such cache is found, nothing is done.
    """
    cache_owner = _find_ancestor_with_offload_cache(original)
    if cache_owner is None:
        return

    offload_kwargs = get_cache_init_kwargs(cache_owner)

    for submodule in replacement.modules():
        # The OffloadCache check comes first so that parameters of an
        # already-offloaded module are never iterated.
        needs_offload = not isinstance(submodule._parameters, OffloadCache) and any(
            True for _ in submodule.parameters(recurse=False)
        )
        if needs_offload:
            offload_module(submodule, **offload_kwargs)