diff --git a/docs/.nav.yml b/docs/.nav.yml index fe03240b0f..8f2e576e59 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -39,6 +39,11 @@ nav: - Observers: guides/observers.md - Memory Requirements: guides/memory.md - Runtime Performance: guides/runtime.md + - Developer Guides: + - developer-tutorials/index.md + - Adding a New Modifier: developer-tutorials/add-modifier.md + - Adding a New Observer: developer-tutorials/add-observer.md + - Adding MoE Calibration Support for a New Model: developer-tutorials/add-moe-support.md - Examples: - examples/README.md - examples/* diff --git a/docs/developer-tutorials/add-modifier.md b/docs/developer-tutorials/add-modifier.md new file mode 100644 index 0000000000..53fa85029c --- /dev/null +++ b/docs/developer-tutorials/add-modifier.md @@ -0,0 +1,255 @@ +# Adding a New Modifier + +Modifiers are the core extension point in LLM Compressor. Each compression algorithm — GPTQ, AWQ, SmoothQuant, and others — is implemented as a modifier. This tutorial walks through the modifier contract, lifecycle, and how to implement a custom one. + +## What is a Modifier? + +A modifier is a Pydantic model that hooks into the compression pipeline at well-defined lifecycle points. When you call `oneshot`, LLM Compressor: + +1. Instantiates modifiers from the recipe +2. Calls `initialize` on each modifier +3. Runs calibration batches, firing `Event`s that modifiers respond to +4. Calls `finalize` on each modifier + +Modifiers express what they want to do at each stage by overriding lifecycle hooks. + +## How Events Work + +Not all lifecycle hooks are driven by events. `on_initialize` and `on_finalize` are called directly by `CompressionLifecycle` — before and after the pipeline runs respectively. Everything in between is event-driven. 
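To make the dispatch order concrete, here is a toy sketch (plain Python, not LLM Compressor's actual classes — `ToyModifier` and its attributes are invented for illustration) of the base-class behavior described above: `on_event` fires unconditionally for every event, `on_start` fires once on the first `BATCH_START`, and subsequent events route to `on_update`:

```python
from dataclasses import dataclass
from enum import Enum, auto

class EventType(Enum):
    BATCH_START = auto()
    BATCH_END = auto()

@dataclass
class Event:
    type_: EventType

class ToyModifier:
    """Toy illustration of event dispatch order — not the real Modifier."""
    def __init__(self):
        self.log = []
        self.started = False

    def update_event(self, event):
        self.log.append(("on_event", event.type_))  # always fires first
        if not self.started and event.type_ == EventType.BATCH_START:
            self.started = True
            self.log.append(("on_start", event.type_))  # first BATCH_START only
        elif self.started:
            self.log.append(("on_update", event.type_))  # while active

mod = ToyModifier()
for t in [EventType.BATCH_START, EventType.BATCH_END]:
    mod.update_event(Event(t))
```

After the two events, `mod.log` shows `on_event` firing for both, `on_start` only once, and `on_update` for the second batch event — the same ordering the table below summarizes.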
+ +| Hook | Called by | +|------|-----------| +| `on_initialize` | `CompressionLifecycle.initialize()` | +| `on_event` / `on_start` / `on_update` / `on_end` | `CompressionLifecycle.event()` → `Modifier.update_event()` | +| `on_finalize` | `CompressionLifecycle.finalize()` | + +The pipeline fires events by calling methods on `LifecycleCallbacks` (aliased as `callbacks`), which routes them through the active session into `CompressionLifecycle.event()`. Modifiers never fire events themselves — they only react to them. + +## The Modifier Contract + +All modifiers subclass `llmcompressor.modifiers.Modifier` and must implement `on_initialize`. All other hooks are optional. + +```python +from llmcompressor.modifiers import Modifier +from llmcompressor.core import State, Event + +class MyModifier(Modifier): + # Pydantic fields — declare your parameters here + my_param: float = 1.0 + + def on_initialize(self, state: State, **kwargs) -> bool: + # Called once before calibration begins. + # Set up hooks, attach attributes to modules, etc. + # Return True if initialization succeeded. + ... + return True + + def on_start(self, state: State, event: Event, **kwargs): + # Called when calibration starts. + # The base class dispatches this on the first BATCH_START event, but in + # practice most modifiers trigger it themselves from on_event by checking + # for CALIBRATION_EPOCH_START (see note below). + ... + + def on_update(self, state: State, event: Event, **kwargs): + # Called on every event while the modifier is active (between on_start and + # on_end). Rarely needed — only useful for per-batch callbacks such as + # dynamic pruning schedules. Compression modifiers (GPTQ, AWQ, etc.) do + # not use this hook. + ... + + def on_end(self, state: State, event: Event, **kwargs): + # Called when calibration ends. + # The base class dispatches this on BATCH_END, but in practice all + # modifiers call it manually from on_event on CALIBRATION_EPOCH_END. + ... 
+ + def on_event(self, state: State, event: Event, **kwargs): + # Called on every event, unconditionally, before on_start/on_update/on_end + # dispatch. Override to respond to specific EventTypes such as + # CALIBRATION_EPOCH_START or SEQUENTIAL_EPOCH_END that fall outside + # the BATCH_START / BATCH_END pattern. + ... + + def on_finalize(self, state: State, **kwargs) -> bool: + # Called after calibration completes. + # Clean up hooks, apply final transformations, etc. + # Return True if finalization succeeded. + ... + return True +``` + +### Lifecycle Summary + +| Hook | When it runs | Required | +|------|-------------|----------| +| `on_initialize` | Once, before calibration | Yes | +| `on_event` | Every event, unconditionally (before start/update/end dispatch) | No | +| `on_start` | First `BATCH_START` event (base class); most modifiers call it manually from `on_event` on `CALIBRATION_EPOCH_START` | No | +| `on_update` | Every event while active (between `on_start` and `on_end`); rarely used outside of pruning modifiers | No | +| `on_end` | `BATCH_END` when modifier ends (base class); in practice all modifiers call it manually from `on_event` on `CALIBRATION_EPOCH_END` | No | +| `on_finalize` | Once, after calibration | No | + +> **Note on `on_start` / `on_end` vs `on_event`:** The base class dispatches `on_start` on the first `BATCH_START` event and `on_end` on `BATCH_END`. However, all built-in modifiers (GPTQ, AWQ, SmoothQuant, SparseGPT, etc.) bypass this by overriding `on_event` and calling `self.on_start()` / `self.on_end()` themselves on `CALIBRATION_EPOCH_START` / `CALIBRATION_EPOCH_END`. If you are writing a new modifier, follow this pattern. + +### The `State` Object + +`state.model` gives you the `torch.nn.Module` being compressed. This is the primary object you will interact with in most hooks. + +### Pydantic Parameters + +Because `Modifier` is a Pydantic model with `extra="forbid"`, all parameters must be declared as class-level fields. 
This also means your modifier can be instantiated directly in Python or from a YAML recipe. + +```python +class MyModifier(Modifier): + targets: list[str] = ["Linear"] + scale_factor: float = 0.5 + ignore: list[str] = [] +``` + +## Attaching Hooks with `HooksMixin` + +`Modifier` inherits from `HooksMixin`, which provides a managed way to register PyTorch forward hooks. Hooks registered through `HooksMixin` are automatically removed when `finalize` is called. + +```python +from llmcompressor.modifiers import Modifier +from llmcompressor.core import State + +class MyModifier(Modifier): + def on_initialize(self, state: State, **kwargs) -> bool: + for name, module in state.model.named_modules(): + if "Linear" in type(module).__name__: + self.register_hook( + module, + self._forward_hook, + "forward", + ) + return True + + def _forward_hook(self, module, inputs, output): + # Runs after every forward pass through this module + ... +``` + +## Example: A Weight-Clamping Modifier + +The following modifier clamps the stored weight tensors of all `Linear` layers to a fixed absolute magnitude. It follows the real modifier pattern by handling both pipelines in `on_event`: when using the sequential pipeline it clamps weights layer-by-layer as each subgraph completes (`SEQUENTIAL_EPOCH_END`), and when using the basic pipeline it clamps all weights at once at the end of calibration (`CALIBRATION_EPOCH_END`). + +```python +import torch +from compressed_tensors.utils import match_named_modules +from llmcompressor.modifiers import Modifier +from llmcompressor.core import State, Event, EventType + +class WeightClampModifier(Modifier): + """ + Clamps the magnitude of Linear layer weight tensors to a maximum absolute + value. Applied layer-by-layer on SEQUENTIAL_EPOCH_END (sequential pipeline) + or all at once on CALIBRATION_EPOCH_END (basic pipeline). 
    :param max_weight_magnitude: maximum allowed absolute value for any weight
    :param targets: module types to target
    :param ignore: module names to skip
    """

    max_weight_magnitude: float = 1.0
    targets: list[str] = ["Linear"]
    ignore: list[str] = []

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.max_weight_magnitude <= 0:
            raise ValueError("max_weight_magnitude must be positive")

        # Verify that at least one target module exists in the model
        matched = list(match_named_modules(state.model, self.targets, self.ignore))
        if not matched:
            raise ValueError(
                f"No modules matched targets={self.targets} ignore={self.ignore}"
            )

        self._clamped: set[str] = set()
        return True

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, event)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            # Sequential pipeline: clamp weights for the just-finished subgraph.
            # Materialize modules() into a set — it may be a generator, which
            # would be consumed by the first membership check.
            subgraph = kwargs.get("subgraph")
            if subgraph is not None:
                self._clamp_modules(state, modules=set(subgraph.modules()))

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            # Basic pipeline: clamp any modules not yet handled
            self._clamp_modules(state)
            if not self.ended_:
                self.on_end(state, event)

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def _clamp_modules(self, state: State, modules=None):
        for name, module in match_named_modules(
            state.model, self.targets, self.ignore
        ):
            if name in self._clamped:
                continue
            if modules is not None and module not in modules:
                continue
            with torch.no_grad():
                module.weight.clamp_(
                    -self.max_weight_magnitude,
                    self.max_weight_magnitude,
                )
            self._clamped.add(name)

    def on_finalize(self, state: State, **kwargs) -> bool:
        self._clamped.clear()
        return True
```

### Using the
Modifier with `oneshot` + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from llmcompressor import oneshot + +model = AutoModelForCausalLM.from_pretrained("your-model") +tokenizer = AutoTokenizer.from_pretrained("your-model") + +oneshot( + model=model, + recipe=[WeightClampModifier(max_weight_magnitude=0.5, targets=["Linear"], ignore=["lm_head"])], +) + +model.save_pretrained("your-model-clamped", save_compressed=True) +tokenizer.save_pretrained("your-model-clamped") +``` + +### Using the Modifier from a YAML Recipe + +```yaml +weight_clamp_stage: + weight_clamp_modifiers: + WeightClampModifier: + max_weight_magnitude: 0.5 + targets: + - Linear + ignore: + - lm_head + +``` + +## Tips + +- **Only override what you need.** The default implementations of `on_start`, `on_update`, `on_end`, and `on_finalize` are no-ops or return `True` — you do not need to call `super()` for these. +- **Use `match_named_modules`** (from `compressed_tensors.utils`) to filter modules by type name or path pattern, consistent with how other modifiers handle `targets` and `ignore`. +- **Keep `on_initialize` lightweight.** Expensive operations (e.g., full-model passes) should be deferred to `on_start` or `on_finalize`. +- **Prefer `on_event` over `on_start` for epoch-level control.** Most modifiers override `on_event` and call `self.on_start()` manually on `CALIBRATION_EPOCH_START` rather than relying on the base class `BATCH_START` dispatch. +- **`on_update` is rarely needed.** Only override it if you need a per-batch callback while the modifier is active — e.g., `MagnitudeModifier` uses it to update sparsity each batch. Compression modifiers (GPTQ, AWQ, SmoothQuant, etc.) do not use it. +- **Modifiers never fire events — the pipeline does.** All lifecycle events (`CALIBRATION_EPOCH_START`, `BATCH_START`, `SEQUENTIAL_EPOCH_END`, etc.) are fired by the calibration pipeline. Your modifier only reacts to them. 
The sequential pipeline additionally fires `SEQUENTIAL_EPOCH_END` between layer groups, which modifiers like GPTQ and SparseGPT use to trigger compression layer-by-layer. diff --git a/docs/developer-tutorials/add-moe-support.md b/docs/developer-tutorials/add-moe-support.md new file mode 100644 index 0000000000..400c36b206 --- /dev/null +++ b/docs/developer-tutorials/add-moe-support.md @@ -0,0 +1,241 @@ +# Adding MoE Calibration Support for a New Model + +Mixture of Experts (MoE) models route each token to only a subset of expert layers. This creates a calibration problem: experts that are not activated for a given token never see calibration data, which can result in poorly calibrated quantization parameters, numerical instability, or NaNs. + +LLM Compressor solves this by replacing MoE modules with calibration-friendly versions that route all tokens through all experts during calibration, while keeping only the routed expert outputs for the final result. + +For background, see [Quantizing MoEs with a custom definition](../../examples/quantizing_moe/README.md#quantizing-moes-with-a-custom-definition). + +## When Do You Need This? + +You need a calibration module definition when: + +- Quantizing with a **data-dependent algorithm** (GPTQ, AWQ, AutoRound) on an MoE model +- Quantizing with **static activation quantization** (FP8 per-tensor, INT8 per-tensor, NVFP4) on an MoE model + +Simple weight-only data-free quantization (e.g., RTN W4A16) does not require calibration data and is not affected. + +## The MoECalibrationModule Contract + +All MoE calibration modules subclass `MoECalibrationModule` and must: + +1. Be decorated with `@MoECalibrationModule.register("OriginalClassName")` where `OriginalClassName` is the exact class name of the MoE block being replaced +2. Implement `__init__(self, original, config, calibrate_all_experts=True)` accepting the original module instance +3. 
Implement `forward()` with the same input/output signature as the original, routing all tokens through all experts when `calibrate_all_experts=True` +4. Set `is_permanent` to control whether the module is restored after calibration + +If `is_permanent=True`, the calibration module stays in place after calibration and is used for inference. This is necessary when the model's expert weights are stored in a packed format (e.g., a single 3D tensor) that must be unpacked for per-expert quantization and vLLM compatibility. If `is_permanent=False`, implement `restore(original)` to return the original module after calibration. + +```python +import torch +from llmcompressor.modeling.moe_context import MoECalibrationModule + + +@MoECalibrationModule.register("MyModelMoE") # exact class name from transformers +class CalibrationMyModelMoE(MoECalibrationModule): + + is_permanent = True # stays in place for vLLM compatibility + + def __init__(self, original, config, calibrate_all_experts: bool = True): + super().__init__() + self.experts = ... # unpack or copy experts from original + self.router = original.router + self.calibrate_all_experts = calibrate_all_experts + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + ... 
+``` + +## The `forward` Pattern + +The key behavior difference between normal MoE routing and calibration routing: + +- **Normal routing**: only tokens selected by the router run through each expert +- **Calibration routing**: all tokens run through every expert (but only the routed tokens contribute to the output) + +The Llama4 pattern — where the router returns separate scores and logits and a shared expert always runs on all tokens: + +```python +def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + out = self.shared_expert(hidden_states) # always runs on all tokens + + _, router_indices = torch.topk(router_logits, self.top_k, dim=1) + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ).permute(2, 1, 0) # (num_experts, top_k, batch_size * seq_len) + + for i in range(self.num_experts): + token_idx = torch.where(expert_mask[i].squeeze(0)) + + if self.calibrate_all_experts: + # Run ALL tokens through the expert to collect calibration statistics. + # Only the routed tokens contribute to the output. + expert_out = self.experts[i](hidden_states)[token_idx] + else: + expert_out = self.experts[i](hidden_states[token_idx]) + + if len(token_idx) > 0: + weighted_output = expert_out * router_scores[:, i][token_idx].reshape(-1, 1) + out[token_idx] += weighted_output + + return out, router_logits +``` + +!!! note + The routing scores are applied to the expert **output** rather than the input. Applying scores to the input before passing to the expert can produce NaNs during calibration. + +## Example: Llama4 + +The existing `SequentialLlama4TextMoe` (in `src/llmcompressor/modeling/llama4.py`) is the canonical reference implementation. 
It registers as a replacement for `Llama4TextMoe` and handles a key Llama4-specific detail: expert weights are stored as a single packed 3D tensor (`gate_up_proj` of shape `(num_experts, hidden, 2*intermediate)`) which must be unpacked into individual `Llama4TextMLP` modules for per-expert calibration. + +This is handled by a helper class `SequentialLlama4TextExperts` that converts the packed tensor into a `ModuleList` of unpacked experts: + +```python +class SequentialLlama4TextExperts(torch.nn.ModuleList): + def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): + self.num_experts = original.gate_up_proj.shape[0] + with skip_weights_initialize(): + super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)]) + + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + gate_proj, up_proj = gate_up.chunk(2, dim=-1) + + self[i].gate_proj.weight.data = gate_proj.t().contiguous() + self[i].up_proj.weight.data = up_proj.t().contiguous() + self[i].down_proj.weight.data = down.t().contiguous() +``` + +Key details from the Llama4 implementation: + +- `is_permanent = True` — the unpacked expert form is required for vLLM inference, so the module is not restored after calibration +- Expert weights are unpacked from a 3D packed tensor into a `ModuleList` of individual MLPs +- The config passed to `__init__` is a multimodal `Llama4Config`; text-specific settings are extracted via `config.get_text_config()` +- A `shared_expert` runs on all tokens unconditionally and its output is used as the accumulation base + +## Step-by-Step: Adding Support for a New Model + +### Step 1: Identify the MoE block class name + +Find the class in the transformers library that implements the MoE routing for your model: + +```python +from transformers.models.your_model.modeling_your_model import YourModelMoE +print(YourModelMoE.__name__) # e.g. 
"YourModelMoE" +``` + +### Step 2: Determine whether experts are packed + +Inspect the original MoE module to see how experts are stored: + +```python +import inspect +print(inspect.getsource(YourModelMoE.__init__)) +``` + +- If experts are stored as a `ModuleList` of individual layers, you can copy them directly. +- If experts are stored as a packed 3D tensor (like Llama4), you need a helper class to unpack them into a `ModuleList` before calibration. + +### Step 3: Create the calibration module + +Create a new file `src/llmcompressor/modeling/your_model.py`: + +```python +from typing import Tuple + +import torch +from transformers.models.your_model.configuration_your_model import YourModelConfig +from transformers.models.your_model.modeling_your_model import YourModelMoE as OriginalYourModelMoE + +from llmcompressor.modeling.moe_context import MoECalibrationModule + + +@MoECalibrationModule.register("YourModelMoE") +class SequentialYourModelMoE(MoECalibrationModule): + """ + Calibration version of YourModelMoE that sends all tokens to all experts + during calibration to ensure proper quantization statistics are collected. 
+ """ + + is_permanent = True # set False if unpacking is not needed and you want restoration + + def __init__( + self, + original: OriginalYourModelMoE, + config: YourModelConfig, + calibrate_all_experts: bool = True, + ): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + + # Unpack packed experts if needed, or copy directly: + # self.experts = SequentialYourModelExperts(config, original.experts) + self.experts = original.experts + self.router = original.router + self.shared_expert = original.shared_expert + self.calibrate_all_experts = calibrate_all_experts + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + out = self.shared_expert(hidden_states) + + _, router_indices = torch.topk(router_logits, self.top_k, dim=1) + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ).permute(2, 1, 0) + + for i in range(self.num_experts): + token_idx = torch.where(expert_mask[i].squeeze(0)) + + if self.calibrate_all_experts: + expert_out = self.experts[i](hidden_states)[token_idx] + else: + expert_out = self.experts[i](hidden_states[token_idx]) + + if len(token_idx) > 0: + weighted_output = expert_out * router_scores[:, i][token_idx].reshape(-1, 1) + out[token_idx] += weighted_output + + return out, router_logits +``` + +### Step 4: Import the calibration module at the call site + +The `@MoECalibrationModule.register(...)` decorator only takes effect when the module is imported. 
Import it before calling `oneshot`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modeling.your_model import SequentialYourModelMoE  # noqa: F401
from llmcompressor.modifiers.quantization import QuantizationModifier

model_id = "your-org/your-moe-model"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

oneshot(
    model=model,
    dataset=ds,  # a tokenized calibration dataset prepared beforehand
    recipe=[QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])],
    num_calibration_samples=512,
    max_seq_length=2048,
)

model.save_pretrained("your-moe-model-NVFP4", save_compressed=True)
tokenizer.save_pretrained("your-moe-model-NVFP4")
```

## Tips

- **The register name must exactly match the original class name** (case-sensitive). Inspect `module.__class__.__name__` if unsure.
- **Check whether experts are packed.** If the model stores experts as a single high-dimensional tensor rather than a `ModuleList`, you need to unpack them before calibration — see the `SequentialLlama4TextExperts` pattern.
- **Match the original `forward` signature exactly**, including return values. Llama4, for example, returns `(out, router_logits)`.
- **Apply routing scores to expert outputs, not inputs.** Applying scores before the expert forward pass can produce NaNs during calibration.
- **Use `is_permanent=True` when the unpacked form is required for inference** (e.g., vLLM needs individual expert modules). Use `is_permanent=False` when you only need calibration coverage and want the original structure restored afterwards.
- **Test with a small model or a few calibration samples first** to confirm all experts are reached and no NaNs appear.
diff --git a/docs/developer-tutorials/add-observer.md b/docs/developer-tutorials/add-observer.md new file mode 100644 index 0000000000..66d5f71289 --- /dev/null +++ b/docs/developer-tutorials/add-observer.md @@ -0,0 +1,212 @@ +# Adding a New Observer + +Observers analyze weight and activation tensors during calibration to compute the statistics needed for quantization. This guide explains how observers fit into the quantization pipeline and how to implement a custom one. + +## What is an Observer? + +When a quantized layer runs a calibration forward pass, it passes the weight or activation tensor to an observer. The observer's job is to compute **min and max values** from the observed tensor. These min/max values are then passed to `compressed_tensors.quantization.utils.calculate_qparams`, which converts them into `scale` and `zero_point` tensors used for quantization. + +Observers do **not** compute scales or zero points directly — that responsibility belongs to `compressed-tensors`. The observer's only job is to characterize the tensor's range via min and max values. + +For schemes that require a global scale (e.g., NVFP4, MXFP4), the observer's `get_global_min_max` output is similarly passed to `compressed_tensors.quantization.utils.generate_gparam`, which generates the global scale used to keep per-group local scales within a target dtype range (e.g., FP8 for NVFP4 group scales). + +The base `Observer` class handles all slicing and reshaping for group-wise, channel-wise, and token-wise strategies before calling your subclass. 
Your subclass only needs to answer: **given this pre-shaped tensor, what are the min and max values?** + +## The Observer Contract + +All observers subclass `llmcompressor.observers.Observer` and must implement two methods: + +```python +import torch +from llmcompressor.observers import Observer +from llmcompressor.observers.base import MinMaxTuple + +@Observer.register("my_observer") +class MyObserver(Observer): + + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + """ + Compute min and max from the observed tensor. + + The base class has already reshaped the tensor into + shape (num_observations, *qparam_shape, group_size). + These min/max values are passed to calculate_qparams + in compressed-tensors to produce scale and zero_point. + + :param observed: pre-processed tensor ready for statistics computation + :return: (min_vals, max_vals) with shape (*qparam_shape,) + """ + ... + + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + """ + Compute global min and max for global scale calculation (e.g., NVFP4, MXFP4). + + The base class reshapes the tensor to (1, 1, total_elements) before calling + this method. The returned values are passed to generate_gparam in + compressed-tensors to produce a global scale that keeps per-group local + scales within the target dtype range. + + :param observed: per-tensor reshaped tensor + :return: (min_val, max_val) scalar tensors of shape (1,) + """ + ... +``` + +The `@Observer.register("my_observer")` decorator registers your observer under the given name so it can be referenced in recipes by string. 
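To make the shape contract concrete, here is a small standalone check in plain PyTorch (outside any observer class — the shapes are an illustrative example, not fixed by the API) of the reductions a typical `get_min_max` performs on the pre-shaped tensor described above:

```python
import torch

# Suppose the base class hands us a weight observed in 4 groups of 16
# elements each: (num_observations, *qparam_shape, group_size) = (1, 4, 16)
observed = torch.randn(1, 4, 16)

# Reduce over the observation dim (0) and the group dim (-1), leaving one
# min/max per quantization group — shape (*qparam_shape,) = (4,)
min_vals = torch.amin(observed, dim=(0, -1))
max_vals = torch.amax(observed, dim=(0, -1))
assert min_vals.shape == (4,) and max_vals.shape == (4,)

# For get_global_min_max the tensor arrives reshaped to (1, 1, total_elements),
# and the result must be a pair of shape-(1,) tensors
flat = observed.reshape(1, 1, -1)
g_min = torch.amin(flat).unsqueeze(0)
g_max = torch.amax(flat).unsqueeze(0)
assert g_min.shape == (1,) and g_max.shape == (1,)
```

These two reductions are exactly the outputs that `calculate_qparams` and `generate_gparam` consume downstream.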
+ +## How the Base Class Uses Your Output + +The base class `_forward_with_minmax` method calls your `get_min_max` and passes the result directly to `calculate_qparams` from `compressed-tensors`: + +```python +# Inside Observer._forward_with_minmax (simplified): +min_vals, max_vals = self.get_min_max(observed) +scales, zero_points = calculate_qparams( + min_vals=min_vals, + max_vals=max_vals, + quantization_args=self.args, + global_scale=global_scale, +) +``` + +`calculate_qparams` handles the actual scale and zero point computation — symmetric vs asymmetric quantization, dtype clamping, MX scale generation, and so on. Your observer only controls the min/max values fed into it. + +For global scales (FP4 schemes), the base class calls your `get_global_min_max` and passes the result to `generate_gparam`: + +```python +# Inside Observer._get_global_scale_with_minmax (simplified): +global_min_vals, global_max_vals = self.get_global_min_max(observed) +global_scale = generate_gparam(global_min_vals, global_max_vals) +``` + +## Stateful Observers + +Some observers accumulate statistics across multiple calibration batches. To do this, initialize state in `__init__` and update it in `get_min_max`: + +```python +@Observer.register("my_observer") +class MyObserver(Observer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.running_min = None + self.running_max = None + + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) + + if self.running_min is not None: + min_vals = torch.min(min_vals, self.running_min) + max_vals = torch.max(max_vals, self.running_max) + + self.running_min = min_vals + self.running_max = max_vals + + return min_vals, max_vals +``` + +## Example: A Percentile-Clipping Observer + +The following observer clips outliers by returning min/max values from a configurable percentile range rather than the absolute extremes. 
This can improve accuracy when tensors have extreme outlier values that would otherwise inflate the quantization range.

```python
import torch
from llmcompressor.observers import Observer
from llmcompressor.observers.base import MinMaxTuple

@Observer.register("percentile")
class PercentileObserver(Observer):
    """
    Returns per-channel min/max values clipped to a configurable percentile
    range, discarding outliers beyond the given percentile. The resulting
    min/max values are passed to calculate_qparams in compressed-tensors
    to produce scale and zero_point.

    Configure via observer_kwargs:
        percentile (float): the upper percentile to retain, e.g. 99.9
    """

    def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        percentile = self.args.observer_kwargs.get("percentile", 99.9)
        lower = 100.0 - percentile
        upper = percentile

        # Index dim -2 (the qparam dim being iterated); keep the trailing
        # group dim so the quantile is taken over all of a channel's values.
        min_vals = torch.tensor(
            [
                torch.quantile(observed[..., i, :], lower / 100.0).item()
                for i in range(observed.shape[-2])
            ]
        )
        max_vals = torch.tensor(
            [
                torch.quantile(observed[..., i, :], upper / 100.0).item()
                for i in range(observed.shape[-2])
            ]
        )

        return min_vals, max_vals

    def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        percentile = self.args.observer_kwargs.get("percentile", 99.9)
        lower = 100.0 - percentile
        flat = observed.flatten()
        min_val = torch.quantile(flat, lower / 100.0).unsqueeze(0)
        max_val = torch.quantile(flat, percentile / 100.0).unsqueeze(0)
        return min_val, max_val
```

### Using the Observer in a Recipe

Reference the registered name (`"percentile"`) via the `observer` field in `QuantizationArgs`:

```python
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationArgs

recipe = QuantizationModifier(
    targets="Linear",
    scheme={
        "weights": QuantizationArgs(
            num_bits=8,
            type="int",
            symmetric=True,
            strategy="channel",
            observer="percentile",
observer_kwargs={"percentile": 99.5}, + ) + }, + ignore=["lm_head"], +) +``` + +Or from a YAML recipe: + +```yaml +quantization_stage: + quantization_modifiers: + QuantizationModifier: + targets: + - Linear + ignore: + - lm_head + scheme: + weights: + num_bits: 8 + type: int + symmetric: true + strategy: channel + observer: percentile + observer_kwargs: + percentile: 99.5 +``` + +## Tips + +- **Observers return min/max, not scale/zero_point.** The conversion from min/max → scale/zero_point is handled by `calculate_qparams` in `compressed-tensors`. Focus your implementation on accurately characterizing the tensor range. +- **`get_min_max` receives a pre-shaped tensor.** The base class has already sliced the input according to `QuantizationArgs.strategy` (group, channel, token, etc.). You do not need to handle reshaping yourself. +- **`get_global_min_max` is only used for FP4 schemes** (NVFP4, MXFP4) that require a global scale. For standard int8/fp8 quantization, the base class will not call it. +- **`observer_kwargs` is the right place for hyperparameters.** Access them via `self.args.observer_kwargs.get(...)`. +- **Match the shape contract.** `get_min_max` must return tensors of shape `(*qparam_shape,)` — one scalar per quantization group/channel/token. `get_global_min_max` must return shape `(1,)`. +- **Existing observers are good references.** See `min_max.py` for a simple stateless example and `mse.py` for a more complex stateful one. diff --git a/docs/developer-tutorials/index.md b/docs/developer-tutorials/index.md new file mode 100644 index 0000000000..fdd15819aa --- /dev/null +++ b/docs/developer-tutorials/index.md @@ -0,0 +1,33 @@ +# Developer Guides + +These guides are for contributors who want to extend LLM Compressor with new functionality. Each guide walks through the relevant abstractions, the contracts you must fulfill, and a concrete working example. + +## Tutorials + +