[MoE Calibration] Simplify MoE calibration interface #1851
Conversation
👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review. Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.
@kylesayrs @dsikka A few clarifications:
Force-pushed from 7fefaac to ba42881
@sairampillai, regarding DCO, you can ignore that. We can sign it via GitHub once reviewed/approved.
```diff
-if dataset_args is not None and dataset_args.calibrate_moe_context:
-    moe_calibration_context(model, stack)
+stack.enter_context(moe_calibration_context(model))
```
I don't think we want to do this for every case, as not every model will be an MoE.
This is implemented along with a quick check in `modeling/prepare.py` to see whether a particular model has been added to the experts replacement list.

This would end the need for any parameters in `DatasetArgs`, simplifying MoE calibration. Do you think there would be overhead when we do `stack.enter_context()`? Do you recommend a better way to implement this?
I think it's fair to assume that a model is an MoE if its architecture is listed in the `MOE_EXPERTS_REPLACEMENTS` dictionary, therefore entering the context in all cases is fine.
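For illustration, a minimal sketch of how the context can silently no-op for unregistered architectures (the registry contents and helper body here are placeholders, not the PR's exact code):

```python
from contextlib import contextmanager

# Placeholder mirroring the MOE_EXPERTS_REPLACEMENTS registry in modeling/prepare.py.
MOE_EXPERTS_REPLACEMENTS = {
    "DeepseekV3ForCausalLM": ...,
    "Qwen3MoeForCausalLM": ...,
}


@contextmanager
def moe_calibration_context(model):
    arch = model.__class__.__name__
    if arch not in MOE_EXPERTS_REPLACEMENTS:
        # Not a registered MoE architecture: entering the context is a no-op.
        yield model
        return
    # ... swap in calibration experts for `arch` ...
    try:
        yield model
    finally:
        # ... restore the original modules ...
        pass
```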
```diff
-if dataset_args.calibrate_moe_context:
-    moe_calibration_context(model, stack)
+stack.enter_context(moe_calibration_context(model))
```
same as above
This looks good, but I worry that this implementation uses more abstraction than is necessary. I like the idea of "contextual" vs. "permanent" changes, and we should definitely log to the user which one is being used.

Please consider simplifying to a single mapping dictionary and a single ABC class to handle the `from_original` and `restore` functions. Don't be afraid to remove/refactor existing code!
```diff
-if dataset_args is not None and dataset_args.calibrate_moe_context:
-    moe_calibration_context(model, stack)
+stack.enter_context(moe_calibration_context(model))
```
As a small nit, consider entering the MoE context here. Entering the context before the pipeline call comes with some benefits:
- We no longer need to enter the context for each pipeline explicitly.
- We no longer need to enter, exit, and re-enter in cases where multiple pipelines are composed (independent pipeline).

```python
with moe_calibration_context(self.model):
    pipeline(...)
```
Good idea! I will try this change and test it out
```python
    calibrate_all_experts: bool = True,
) -> PreTrainedModel:
    # This function is deprecated. Use moe_calibration_context instead.
    warnings.warn(
```
nit: Use compressed_tensors/deprecated
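For example, a rough sketch of what that could look like, assuming `compressed_tensors` exposes the `deprecated` helper referenced above (the import path and signature are assumptions):

```python
# Sketch only: the import path and decorator signature are assumptions based on
# the reviewer's reference to compressed_tensors' deprecated helper.
from compressed_tensors.utils import deprecated


@deprecated(message="Use moe_calibration_context instead.")
def replace_modules_for_calibration(model, calibrate_all_experts: bool = True):
    ...
```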
```python
    moe_context.apply(model, calibrate_all_experts)

    try:
        yield model
```
It seems like this yield value is never used. Do you still want to include it?
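A minimal sketch of the bare-`yield` alternative, assuming call sites never bind the context's value (not the PR's actual code):

```python
from contextlib import contextmanager


@contextmanager
def moe_calibration_context(model):
    # ... swap in calibration modules ...
    try:
        # A bare `yield` suffices when callers write `with moe_calibration_context(model):`;
        # `yield model` only matters for `with moe_calibration_context(model) as model:`.
        yield
    finally:
        # ... restore the original modules ...
        pass
```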
```python
        default=512,
        metadata={"help": "Number of samples to use for one-shot calibration"},
    )
    calibrate_moe_context: bool = field(
```
Can we add a dataset argument called `moe_calibrate_all_experts` which defaults to `True`?
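A sketch of what that field could look like on the dataset arguments dataclass (the field name comes from this comment; the class name, help text, and exact wiring are assumptions):

```python
from dataclasses import dataclass, field


@dataclass
class DatasetArgs:
    # Hypothetical sketch of the proposed field; help text is illustrative only.
    moe_calibrate_all_experts: bool = field(
        default=True,
        metadata={
            "help": "Whether to route calibration tokens through all experts so "
            "every expert receives calibration data."
        },
    )
```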
"Qwen3MoeForCausalLM": MoEModelConfig( | ||
calibration_type=MoECalibrationType.CONTEXTUAL, | ||
target_class_name="Qwen3MoeDecoderLayer", | ||
target_attribute="mlp", |
Ideally we shouldn't ever need to target attributes, only target parent modules. For example, only targeting `Qwen3MoeMLP`.
I'm actually unfamiliar; do we need to specify `target_attribute`s?
```python
# Registry for MoE calibrations
_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {}
```
I think having `_MOE_CONTEXTS`, `replacements`, and `MOE_EXPERTS_REPLACEMENTS` is more than we need.

Additionally, mapping from model class names to replacement modules will run into issues for nested modules (for example, if an MoE model is inside of a nested model architecture, we will not be able to replace its modules).

Ideally we could simplify this to just one dictionary which maps module class names to their contexts. We can also use an ABC to define the `from_original` and `restore` interfaces. For example:
```python
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Dict, Type, TypeVar

T = TypeVar("T")


class MoECalibrationModule(ABC):
    is_permanent = False

    @classmethod
    @abstractmethod
    def from_original(cls, original: T, calibrate_all_experts: bool = True) -> "MoECalibrationModule":
        # converts from the original module to an MoE calibration module
        ...

    @abstractmethod
    def restore(self) -> T:
        # might include repacking weights/qparams into a 3d structure, or not
        ...


from llmcompressor.modeling.deepseek_v3 import CalibrationDeepseekV3MoE

moe_modules: Dict[str, Type[MoECalibrationModule]] = {
    "DeepseekV3MoE": CalibrationDeepseekV3MoE,
    # ...
}


@contextmanager
def moe_calibration_context(model, dataset_args):
    for name, module in model.named_modules():
        if module.__class__.__name__ in moe_modules:
            replacement = moe_modules[module.__class__.__name__].from_original(
                module, dataset_args.calibrate_all_experts
            )
            model.set_submodule(name, replacement)
    # ... maybe some logging about if/which modules were replaced
    # ... maybe some logging about whether the structure will stay (`is_permanent`)

    yield

    for name, module in model.named_modules():
        if isinstance(module, MoECalibrationModule):
            original = module.restore()
            model.set_submodule(name, original)
```
While keying by module class names rather than by model class names requires iterating through all of the modules in the model in order to check, I think this is minimal overhead and acceptable in order to support nested models. What do you think @dsikka?
```python
        self.calibration_type == MoECalibrationType.CONTEXTUAL
        and self.target_attribute is None
    ):
        raise ValueError("target_attribute is required for contextual calibration")
```
Is this coupling necessary? Why do we need to specify an attribute at all?
```python
    return update_function


def register_moe_model(model_class_name: str, config: MoEModelConfig):
```
This might be more abstraction than is necessary
Introduce standardized MoE calibration interface and deprecate legacy `replace_modules_for_calibration`

Summary

Implements a standardized `MoeContextCalibration` class and a simplified registration interface for MoE model calibration, making MoE model integration easier and deprecating the legacy `replace_modules_for_calibration` function.

Problem

MoE model calibration currently requires module replacement logic scattered across `replace_modules_for_calibration` and `moe_calibration_context`. This makes contributing new MoE model support difficult. Additionally, the `DatasetArgs.calibrate_moe_context` parameter created confusion by being optional when MoE calibration should always execute by default.

Relevant Issues

Fixes #1829

Solution

- `MoeContextCalibration` abstract base class with `ContextualMoECalibration` and `PermanentMoECalibration` implementations
- `MoEModelConfig` dataclass and automatic registration system
- `replace_modules_for_calibration` deprecated with warnings
- `calibrate_moe_context` parameter removed - MoE context is handled automatically by pipelines

A usage sketch of the registration interface follows after the test plan below.

Test Plan

Testing

✅ All unit tests pass
✅ Contextual and permanent calibration types working correctly
✅ Model structure correctly changed and restored inside/outside contexts
✅ Linting and type checking pass
✅ Backward compatibility verified with deprecation warnings
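For illustration, a rough sketch of the registration interface described above, assembled from the diff snippets in this PR; the import path and the `model`/`pipeline` placeholders are assumptions, not the PR's exact code:

```python
# Sketch only: import path and placeholder names below are assumptions.
from llmcompressor.modeling.prepare import (
    MoECalibrationType,
    MoEModelConfig,
    moe_calibration_context,
    register_moe_model,
)

# Register an MoE architecture once; pipelines then pick it up automatically.
register_moe_model(
    "Qwen3MoeForCausalLM",
    MoEModelConfig(
        calibration_type=MoECalibrationType.CONTEXTUAL,
        target_class_name="Qwen3MoeDecoderLayer",
        target_attribute="mlp",
    ),
)

# During calibration, pipelines enter the context unconditionally; models whose
# architecture is not registered simply pass through unchanged.
with moe_calibration_context(model):  # `model` is the loaded PreTrainedModel
    pipeline(...)  # `pipeline` stands in for the calibration pipeline call
```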