diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
index 292fa8300..c7838bfbf 100644
--- a/examples/multimodal_vision/llama4_example.py
+++ b/examples/multimodal_vision/llama4_example.py
@@ -3,18 +3,22 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = Llama4Processor.from_pretrained(model_id)
 
-# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
-# This change allows compatibility with vllm.
-# To apply your own custom module for experimentation, consider updating
-# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules will be applied during calibration
+# to enable proper expert calibration and vLLM compatibility.
+#
+# NOTE: This restructuring is specifically required for vLLM compatibility.
+# Users can customize the calibration behavior as needed: implement a custom
+# replacement module in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`),
+# then update `MOE_EXPERTS_REPLACEMENTS` in modeling/prepare.py to reference
+# your custom module.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512
diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md
index ab9e3eb37..a0d458722 100644
--- a/examples/quantization_w4a4_fp4/README.md
+++ b/examples/quantization_w4a4_fp4/README.md
@@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model!
 
 # Quantizing MoEs
 
-To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to:
+MoE calibration is now handled automatically by the pipeline, so no additional steps are required to quantize MoEs. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline applies the appropriate MoE calibration context, which:
 
-1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
-2. Ensure experts are quantized correctly as not all experts are activated during calibration
+1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
+2. Ensures experts are quantized correctly, as not all experts are activated during calibration.
 
-Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
+Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require linearization; instead, it uses contextual MoE calibration, which temporarily replaces each `Qwen3MoeSparseMoeBlock` with the calibration definition in `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py`, changing how the forward pass is handled in the MoE block during calibration. Feel free to update that definition to play around with this behavior and evaluate its impact on quantization performance.
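For reference, the sketch below mirrors what the updated pipelines and tests now do internally. It is a minimal illustration, assuming `model` is an already-loaded MoE checkpoint and `dataloader` yields batches of keyword arguments; only `moe_calibration_context` and `calibration_forward_context` come from this change.

```python
import contextlib

from llmcompressor.modeling.prepare import moe_calibration_context
from llmcompressor.utils.helpers import calibration_forward_context


def calibrate_with_moe_context(model, dataloader):
    """Run calibration forward passes with MoE calibration handled automatically."""
    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
        for batch in dataloader:
            model(**batch)
    # Contextual replacements (e.g. Qwen3) are reverted when the stack exits;
    # permanent replacements (e.g. Llama4, DeepSeek V3) persist for export.
```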
diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py
index 28b57dda9..3a52986ad 100644
--- a/examples/quantization_w4a4_fp4/llama4_example.py
+++ b/examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,18 +3,15 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = Llama4Processor.from_pretrained(model_id)
 
-# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
-# This change allows compatibility with vllm.
-# To apply your own custom module for experimentation, consider updating
-# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules will be applied during calibration
+# to enable proper expert calibration and vLLM compatibility.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 20
diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index 9977584c3..9e5d1ca63 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -2,7 +2,6 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
@@ -20,7 +19,9 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `DeepseekV3MoECalibrate` modules will be applied during calibration
+# to enable proper expert calibration.
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index da816584d..e3e92fdf4 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -7,6 +7,7 @@
 HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets.
""" +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..9de74c20a --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,304 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications + and restores the model to its original state after pipeline execution +2. PermanentMoECalibration: permanently modifies the model and stays in its modified + form after pipeline execution +""" + +import contextlib +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union + +import tqdm +from compressed_tensors.utils import replace_module +from transformers import PreTrainedModel + +from llmcompressor.utils.helpers import patch_attr + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). 
diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
new file mode 100644
index 000000000..9de74c20a
--- /dev/null
+++ b/src/llmcompressor/modeling/moe_context.py
@@ -0,0 +1,304 @@
+"""
+Standardized interface for MoE model calibration.
+MoE calibration context is used to apply MoE calibration modifications to the model.
+There are two types of MoE calibration contexts:
+1. ContextualMoECalibration: uses context managers for temporary modifications
+   and restores the model to its original state after pipeline execution
+2. PermanentMoECalibration: permanently modifies the model and stays in its modified
+   form after pipeline execution
+"""
+
+import contextlib
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Callable, Dict, Optional, TypeVar, Union
+
+import tqdm
+from compressed_tensors.utils import replace_module
+from transformers import PreTrainedModel
+
+from llmcompressor.utils.helpers import patch_attr
+
+T = TypeVar("T", bound="MoECalibrationContext")
+
+
+class MoECalibrationType(Enum):
+    """Enumeration of supported MoE calibration types."""
+
+    PERMANENT = "permanent"
+    CONTEXTUAL = "contextual"
+
+
+@dataclass
+class MoEModelConfig:
+    """
+    Configuration for MoE model calibration.
+
+    This dataclass defines the parameters needed to configure MoE calibration
+    for a specific model architecture. It follows the same pattern used by
+    other model configuration systems in the project (e.g., SmoothQuant, AWQ).
+
+    Attributes:
+        calibration_type: Type of calibration - MoECalibrationType.PERMANENT or
+            MoECalibrationType.CONTEXTUAL
+        target_class_name: The class name of the MoE module to replace
+        replace_function: Function that creates the replacement module,
+            generally defined in modeling/model_name.py
+        target_attribute: For contextual calibration, the attribute to replace
+        description: Optional description of the model configuration
+    """
+
+    calibration_type: MoECalibrationType
+    target_class_name: str
+    replace_function: Callable
+    target_attribute: Optional[str] = None
+    description: Optional[str] = None
+
+    def __post_init__(self):
+        """Validate configuration after initialization."""
+        if (
+            self.calibration_type == MoECalibrationType.CONTEXTUAL
+            and self.target_attribute is None
+        ):
+            raise ValueError("target_attribute is required for contextual calibration")
+
+        if (
+            self.calibration_type == MoECalibrationType.PERMANENT
+            and self.target_attribute is not None
+        ):
+            raise ValueError(
+                "target_attribute should not be set for permanent calibration"
+            )
+
+
+# Registry of MoE model configurations
+# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY
+MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {}
+
+
+class MoECalibrationContext(ABC):
+    """
+    Abstract base class for MoE calibration.
+    This provides a standardized interface for MoE model calibration.
+    """
+
+    @abstractmethod
+    def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
+        """
+        Apply MoE calibration modifications to the model.
+        :param model: The model to modify
+        :param calibrate_all_experts: Whether to calibrate all
+            experts or only routed ones
+        """
+        pass
+
+    @abstractmethod
+    def restore(self, model: PreTrainedModel) -> None:
+        """
+        Restore the model to its original state.
+        :param model: The model to restore
+        """
+        pass
+
+
+class ContextualMoECalibration(MoECalibrationContext):
+    """
+    MoE calibration that uses context managers for temporary modifications.
+    This is suitable for models that need to be restored after calibration.
+    """
+
+    def __init__(self, model_class_name: str, update_function):
+        """
+        Initialize the context manager-based MoE calibration.
+        :param model_class_name: The class name of the model this context applies to
+        :param update_function: Function that applies the MoE modifications
+        """
+        self.model_class_name = model_class_name
+        self.update_function = update_function
+        self._stack = None
+
+    def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
+        """Apply MoE calibration modifications using context managers."""
+        if self._stack is None:
+            self._stack = contextlib.ExitStack()
+            self._stack.__enter__()
+
+        self.update_function(model, self._stack, calibrate_all_experts)
+
+    def restore(self, model: PreTrainedModel) -> None:
+        """Restore the model by exiting the context stack."""
+        if self._stack is not None:
+            self._stack.__exit__(None, None, None)
+            self._stack = None
+
+
+class PermanentMoECalibration(MoECalibrationContext):
+    """
+    MoE calibration context that permanently modifies the model.
+    This is suitable for models that can be loaded in their modified form
+    (e.g., Llama4 in vLLM).
+    """
+
+    def __init__(self, model_class_name: str, replacement_function):
+        """
+        Initialize the permanent MoE calibration.
+        :param model_class_name: The class name of the model this context applies to
+        :param replacement_function: Function that permanently replaces MoE modules
+        """
+        self.model_class_name = model_class_name
+        self.replacement_function = replacement_function
+        self._original_modules = {}
+
+    def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None:
+        """Apply permanent MoE calibration modifications."""
+        # Store original modules for potential restoration
+        for name, module in model.named_modules():
+            if module.__class__.__name__ == self.model_class_name:
+                self._original_modules[name] = module
+
+        # Apply the replacement
+        self.replacement_function(model, calibrate_all_experts)
+
+    def restore(self, model: PreTrainedModel) -> None:
+        """Restore original modules (if needed)."""
+        # For permanent MoE calibrations, restoration is typically not needed
+        # as the model is meant to stay in its modified form
+        pass
+
+
+# Registry for MoE calibrations
+_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {}
+
+
+def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None:
+    """
+    Register a MoE calibration context for a model class.
+    :param model_class_name: The class name of the model
+    :param context: The MoE calibration context to register
+    """
+    _MOE_CONTEXTS[model_class_name] = context
+
+
+def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]:
+    """
+    Get the registered MoE calibration context for a model class.
+    :param model_class_name: The class name of the model
+    :return: The MoE calibration context or None if not found
+    """
+    return _MOE_CONTEXTS.get(model_class_name)
+
+
+def list_supported_models() -> list:
+    """
+    List all model classes that have registered MoE calibration contexts.
+    :return: List of supported model class names
+    """
+    return list(_MOE_CONTEXTS.keys())
+
+
+# Generic factory functions for creating MoE updaters
+def create_permanent_moe_updater(target_class_name: str, replace_function: Callable):
+    """
+    Create a permanent MoE updater function for the given target class.
+
+    Args:
+        target_class_name: The class name to look for in the model
+        replace_function: Function that creates the replacement module
+
+    Returns:
+        A function that can be used with PermanentMoECalibration
+    """
+
+    def update_function(model: PreTrainedModel, calibrate_all_experts: bool):
+        """Update MoE modules for calibration."""
+        for name, module in tqdm.tqdm(list(model.named_modules())):
+            if module.__class__.__name__ == target_class_name:
+                new_module = replace_function(
+                    config=model.config,
+                    module=module,
+                    calibrate_all_experts=calibrate_all_experts,
+                )
+                replace_module(model, name, new_module)
+
+    return update_function
+
+
+def create_contextual_moe_updater(
+    target_class_name: str, target_attr: str, replace_function: Callable
+):
+    """
+    Create a contextual MoE updater function for the given target class and attribute.
+
+    Args:
+        target_class_name: The class name to look for in the model
+        target_attr: The attribute name to replace within the target class
+        replace_function: Function that creates the replacement module
+
+    Returns:
+        A function that can be used with ContextualMoECalibration
+    """
+
+    def update_function(
+        model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool
+    ):
+        """Update MoE modules for calibration using context managers."""
+        for module in model.modules():
+            if module.__class__.__name__ == target_class_name:
+                stack.enter_context(
+                    patch_attr(
+                        module,
+                        target_attr,
+                        replace_function(
+                            config=model.config,
+                            module=getattr(module, target_attr),
+                            calibrate_all_experts=calibrate_all_experts,
+                        ),
+                    )
+                )
+
+    return update_function
+
+
+def register_moe_model(model_class_name: str, config: MoEModelConfig):
+    """
+    Register a MoE model with its configuration.
+
+    Args:
+        model_class_name: The model class name
+        config: MoEModelConfig dataclass instance with calibration parameters
+    """
+    if config.calibration_type == MoECalibrationType.PERMANENT:
+        updater = create_permanent_moe_updater(
+            config.target_class_name, config.replace_function
+        )
+        context = PermanentMoECalibration(config.target_class_name, updater)
+    elif config.calibration_type == MoECalibrationType.CONTEXTUAL:
+        updater = create_contextual_moe_updater(
+            config.target_class_name, config.target_attribute, config.replace_function
+        )
+        context = ContextualMoECalibration(model_class_name, updater)
+    else:
+        raise ValueError(f"Unknown MoE type: {config.calibration_type}")
+
+    register_moe_context(model_class_name, context)
+
+
+def register_moe_model_from_dict(model_class_name: str, config_dict: dict):
+    """
+    Register a MoE model from a dictionary configuration (backward compatibility).
+
+    Args:
+        model_class_name: The model class name
+        config_dict: Dictionary with calibration parameters
+    """
+    # Convert string calibration_type to enum
+    if "calibration_type" in config_dict and isinstance(
+        config_dict["calibration_type"], str
+    ):
+        config_dict["calibration_type"] = MoECalibrationType(
+            config_dict["calibration_type"]
+        )
+
+    config = MoEModelConfig(**config_dict)
+    register_moe_model(model_class_name, config)
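To register calibration support for a new MoE architecture, an entry like the following could be added. This is a hedged sketch: `MyMoeForCausalLM`, `MyMoeDecoderLayer`, and `replace_my_moe` are hypothetical placeholders, not modules that exist in the codebase; only the registry API comes from this change.

```python
from llmcompressor.modeling.moe_context import (
    MoECalibrationType,
    MoEModelConfig,
    register_moe_model,
)


def replace_my_moe(config, module, calibrate_all_experts):
    """Hypothetical factory: return a calibration-aware replacement for `module`."""
    # A real implementation (see modeling/qwen3_moe.py or modeling/llama4.py) would
    # return a module that routes tokens to every expert when
    # calibrate_all_experts=True; here we just return the module unchanged.
    return module


register_moe_model(
    "MyMoeForCausalLM",  # hypothetical model class name
    MoEModelConfig(
        calibration_type=MoECalibrationType.CONTEXTUAL,
        target_class_name="MyMoeDecoderLayer",  # hypothetical decoder layer class
        target_attribute="mlp",
        replace_function=replace_my_moe,
        description="Hypothetical MoE model registered for contextual calibration",
    ),
)
```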
" + "Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: @@ -35,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! +# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. 
+
+    Args:
+        model: The model to apply MoE calibration to
+        calibrate_all_experts: Whether to calibrate all experts or only routed ones
+
+    Yields:
+        The model with MoE calibration applied
+    """
     cls_name = model.__class__.__name__
-    if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
+    moe_context = get_moe_context(cls_name)
+
+    if moe_context is None:
+        # No MoE context registered for this model, yield unchanged
+        yield model
+        return
+
+    # Apply MoE calibration
+    moe_context.apply(model, calibrate_all_experts)
+
+    try:
+        yield model
+    finally:
+        # Restore original state
+        moe_context.restore(model)
diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py
index db4b15305..edcb46b09 100644
--- a/src/llmcompressor/pipelines/basic/pipeline.py
+++ b/src/llmcompressor/pipelines/basic/pipeline.py
@@ -46,9 +46,7 @@ def __call__(
 
         with contextlib.ExitStack() as stack:
             stack.enter_context(calibration_forward_context(model))
-
-            if dataset_args is not None and dataset_args.calibrate_moe_context:
-                moe_calibration_context(model, stack)
+            stack.enter_context(moe_calibration_context(model))
 
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                 batch = apply_pad_mask_to_batch(batch)
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index 565cb81fb..bab732f48 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -82,8 +82,7 @@ def __call__(
             if not dataset_args.quantization_aware_calibration or disable_qac:
                 stack.enter_context(DisableQuantization(model))
 
-            if dataset_args.calibrate_moe_context:
-                moe_calibration_context(model, stack)
+            stack.enter_context(moe_calibration_context(model))
 
             # prepare intermediates cache
             intermediates: IntermediatesCache = capture_first_layer_intermediates(
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index f00a37e43..9827052b7 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -85,8 +85,7 @@ def __call__(
             if not dataset_args.quantization_aware_calibration or disable_qac:
                 stack.enter_context(DisableQuantization(model))
 
-            if dataset_args.calibrate_moe_context:
-                moe_calibration_context(model, stack)
+            stack.enter_context(moe_calibration_context(model))
 
             # prepare intermediates cache
             activations = IntermediatesCache.from_dataloader(dataloader, model_device)
diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py
index ca7fb06af..239365fdd 100644
--- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py
+++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py
@@ -1,3 +1,4 @@
+import contextlib
 from functools import partial
 
 import pytest
@@ -9,7 +10,7 @@
     DeepseekV3MoECalibrate,
     OriginalDeepseekV3MoE,
 )
-from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.modeling.prepare import moe_calibration_context
 from llmcompressor.utils.dev import skip_weights_download
 from llmcompressor.utils.helpers import calibration_forward_context
 from tests.testing_utils import requires_cadence, requires_gpu
@@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub):
     with skip_weights_download():
         model = AutoModelForCausalLM.from_pretrained(model_stub)
 
-    replace_modules_for_calibration(model, calibrate_all_experts=True)
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(calibration_forward_context(model))
+        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
 
-    # Find a Deepseek MoE layer
-    moe_layer = None
-    for _, module in model.named_modules():
-        if isinstance(module, DeepseekV3MoECalibrate):
-            moe_layer = module
-            break
+        # Find a Deepseek MoE layer
+        moe_layer = None
+        for _, module in model.named_modules():
+            if isinstance(module, DeepseekV3MoECalibrate):
+                moe_layer = module
+                break
 
-    assert moe_layer is not None
+        assert moe_layer is not None
 
-    num_experts = len(moe_layer.experts)
-    expert_triggered = [False for _ in range(num_experts)]
+        num_experts = len(moe_layer.experts)
+        expert_triggered = [False for _ in range(num_experts)]
 
-    # Define the hook function
-    def hook_fn(i, module, input, output):
-        expert_triggered[i] = True
+        # Define the hook function
+        def hook_fn(i, module, input, output):
+            expert_triggered[i] = True
 
-    # Attach hooks using functools.partial to bind each index
-    for i, expert in enumerate(moe_layer.experts):
-        expert.register_forward_hook(partial(hook_fn, i))
+        # Attach hooks using functools.partial to bind each index
+        for i, expert in enumerate(moe_layer.experts):
+            expert.register_forward_hook(partial(hook_fn, i))
 
-    # Create dummy input tensor that simulates hidden_states
-    hidden_dim = model.config.hidden_size
-    batch, seq_len = 4, 32
-    sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
+        # Create dummy input tensor that simulates hidden_states
+        hidden_dim = model.config.hidden_size
+        batch, seq_len = 4, 32
+        sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
 
-    # Forward through the MoE layer directly
-    with torch.no_grad():
-        _ = moe_layer(sample)
+        # Forward through the MoE layer directly
+        with torch.no_grad():
+            _ = moe_layer(sample)
 
-    # Assert all experts are used
-    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
+        # Assert all experts are used
+        assert all(
+            expert_triggered
+        ), f"Not all experts were triggered: {expert_triggered}"
 
 
 @requires_gpu
diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py
index 4eb609ca9..d5363d35c 100644
--- a/tests/llmcompressor/modeling/test_calib_llama4.py
+++ b/tests/llmcompressor/modeling/test_calib_llama4.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 from functools import partial
 
@@ -10,7 +11,7 @@
     Llama4TextMoe,
     SequentialLlama4TextMoe,
 )
-from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.modeling.prepare import moe_calibration_context
 from llmcompressor.utils.dev import skip_weights_download
 from llmcompressor.utils.helpers import calibration_forward_context
 from tests.testing_utils import requires_cadence, requires_gpu
@@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub):
         model_stub, torch_dtype="auto"
     )
 
-    replace_modules_for_calibration(model, calibrate_all_experts=True)
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(calibration_forward_context(model))
+        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
 
-    # Find a Llama4 MoE layer
-    moe_layer = None
-    for module in model.modules():
-        if isinstance(module, SequentialLlama4TextMoe):
-            moe_layer = module
-            break
+        # Find a Llama4 MoE layer
+        moe_layer = None
+        for module in model.modules():
+            if isinstance(module, SequentialLlama4TextMoe):
+                moe_layer = module
+                break
 
-    assert moe_layer is not None
+        assert moe_layer is not None
 
-    num_experts = len(moe_layer.experts)
-    expert_triggered = [False for _ in range(num_experts)]
+        num_experts = len(moe_layer.experts)
+        expert_triggered = [False for _ in range(num_experts)]
 
-    # Define the hook function
-    def hook_fn(i, module, input, output):
-        expert_triggered[i] = True
+        # Define the hook function
+        def hook_fn(i, module, input, output):
+            expert_triggered[i] = True
 
-    # Attach hooks using functools.partial to bind each index
-    for i, expert in enumerate(moe_layer.experts):
-        expert.register_forward_hook(partial(hook_fn, i))
+        # Attach hooks using functools.partial to bind each index
+        for i, expert in enumerate(moe_layer.experts):
+            expert.register_forward_hook(partial(hook_fn, i))
 
-    # Create dummy input tensor that simulates hidden_states
-    hidden_dim = model.config.text_config.hidden_size
-    batch, seq_len = 4, 32
-    sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype)
+        # Create dummy input tensor that simulates hidden_states
+        hidden_dim = model.config.text_config.hidden_size
+        batch, seq_len = 4, 32
+        sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype)
 
-    # Forward through the MoE layer directly
-    with torch.no_grad():
-        _ = moe_layer(sample)
+        # Forward through the MoE layer directly
+        with torch.no_grad():
+            _ = moe_layer(sample)
 
-    # Assert all experts are used
-    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
+        # Assert all experts are used
+        assert all(
+            expert_triggered
+        ), f"Not all experts were triggered: {expert_triggered}"
 
 
 @requires_gpu
diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py
index 822af18db..a20acf8a8 100644
--- a/tests/llmcompressor/modeling/test_calib_qwen3.py
+++ b/tests/llmcompressor/modeling/test_calib_qwen3.py
@@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub):
     with contextlib.ExitStack() as stack:
         stack.enter_context(calibration_forward_context(model))
         stack.enter_context(DisableQuantization(model))
-
-        moe_calibration_context(model, stack, calibrate_all_experts=True)
+        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
 
         # Find one MoE layer
         moe_layer = None