Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/.nav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ nav:
- Observers: guides/observers.md
- Memory Requirements: guides/memory.md
- Runtime Performance: guides/runtime.md
- Developer Guides:
- developer-tutorials/index.md
- Adding a New Modifier: developer-tutorials/add-modifier.md
- Adding a New Observer: developer-tutorials/add-observer.md
- Adding MoE Calibration Support for a New Model: developer-tutorials/add-moe-support.md
- Examples:
- examples/README.md
- examples/*
Expand Down
177 changes: 177 additions & 0 deletions docs/developer-tutorials/add-modifier.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Adding a New Modifier

Modifiers are the core extension point in LLM Compressor. Each compression algorithm — GPTQ, AWQ, SmoothQuant, and others — is implemented as a modifier. This tutorial walks through the modifier contract, lifecycle, and how to implement a custom one.

## What is a Modifier?

A modifier is a Pydantic model that hooks into the compression pipeline at well-defined lifecycle points. When you call `oneshot`, LLM Compressor:

1. Instantiates modifiers from the recipe
2. Calls `initialize` on each modifier
3. Runs calibration batches, firing `Event`s that modifiers respond to
4. Calls `finalize` on each modifier

Modifiers express what they want to do at each stage by overriding lifecycle hooks.

## The Modifier Contract

All modifiers subclass `llmcompressor.modifiers.Modifier` and must implement `on_initialize`. All other hooks are optional.

```python
from llmcompressor.modifiers import Modifier
from llmcompressor.core import State, Event

class MyModifier(Modifier):
# Pydantic fields — declare your parameters here
my_param: float = 1.0

def on_initialize(self, state: State, **kwargs) -> bool:
# Called once before calibration begins.
# Set up hooks, attach attributes to modules, etc.
# Return True if initialization succeeded.
...
return True

def on_start(self, state: State, event: Event, **kwargs):
# Called when calibration starts (first BATCH_START event).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment here is slightly misleading. on_start is not necessarily called on the first BATCH_START event of calibration, but rather on the first BATCH_START event for which the modifier becomes active (i.e., when event.current_index >= self.start). For better clarity, I suggest changing the comment.

Suggested change
# Called when calibration starts (first BATCH_START event).
# Called on the BATCH_START event when the modifier's `start` step is reached.

...

def on_update(self, state: State, event: Event, **kwargs):
# Called on every event while the modifier is active.
...

def on_end(self, state: State, event: Event, **kwargs):
# Called when calibration ends.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to on_start, the comment for on_end could be more precise. It's triggered when the modifier's end step is reached, on a BATCH_END event. I suggest updating the comment for accuracy.

Suggested change
# Called when calibration ends.
# Called on the BATCH_END event when the modifier's `end` step is reached.

...

def on_finalize(self, state: State, **kwargs) -> bool:
# Called after calibration completes.
# Clean up hooks, apply final transformations, etc.
# Return True if finalization succeeded.
...
return True
```

### Lifecycle Summary

| Hook | When it runs | Required |
|------|-------------|----------|
| `on_initialize` | Once, before calibration | Yes |
| `on_start` | First `BATCH_START` event | No |
| `on_update` | Every event while active | No |
| `on_end` | `BATCH_END` when modifier ends | No |
| `on_finalize` | Once, after calibration | No |

### The `State` Object

`state.model` gives you the `torch.nn.Module` being compressed. This is the primary object you will interact with in most hooks.

### Pydantic Parameters

Because `Modifier` is a Pydantic model with `extra="forbid"`, all parameters must be declared as class-level fields. This also means your modifier can be instantiated directly in Python or from a YAML recipe.

```python
class MyModifier(Modifier):
targets: list[str] = ["Linear"]
scale_factor: float = 0.5
ignore: list[str] = []
```

## Attaching Hooks with `HooksMixin`

`Modifier` inherits from `HooksMixin`, which provides a managed way to register PyTorch forward hooks. Hooks registered through `HooksMixin` are automatically removed when `finalize` is called.

```python
from llmcompressor.modifiers import Modifier
from llmcompressor.core import State

class MyModifier(Modifier):
def on_initialize(self, state: State, **kwargs) -> bool:
for name, module in state.model.named_modules():
if "Linear" in type(module).__name__:
self.register_hook(
module,
self._forward_hook,
"forward",
)
return True

def _forward_hook(self, module, inputs, output):
# Runs after every forward pass through this module
...
```

## Example: A Weight-Clamping Modifier

The following modifier clamps the absolute magnitude of all `Linear` layer weights during `on_finalize`, after calibration is complete.

```python
import torch
from compressed_tensors.utils import match_named_modules
from llmcompressor.modifiers import Modifier
from llmcompressor.core import State

class WeightClampModifier(Modifier):
"""
Clamps the magnitude of Linear layer weights to a maximum absolute value.

:param max_weight_magnitude: maximum allowed absolute weight value
:param targets: module types to target
:param ignore: module names to skip
"""

max_weight_magnitude: float = 1.0
targets: list[str] = ["Linear"]
ignore: list[str] = []

def on_initialize(self, state: State, **kwargs) -> bool:
return True

def on_finalize(self, state: State, **kwargs) -> bool:
for name, module in match_named_modules(
state.model, self.targets, self.ignore
):
with torch.no_grad():
module.weight.clamp_(
-self.max_weight_magnitude,
self.max_weight_magnitude,
)
return True
```

### Using the Modifier with `oneshot`

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot

model = AutoModelForCausalLM.from_pretrained("your-model")
tokenizer = AutoTokenizer.from_pretrained("your-model")

oneshot(
model=model,
recipe=[WeightClampModifier(max_weight_magnitude=0.5, ignore=["lm_head"])],
)

model.save_pretrained("your-model-clamped", save_compressed=True)
tokenizer.save_pretrained("your-model-clamped")
```

### Using the Modifier from a YAML Recipe

```yaml
weight_clamp_stage:
weight_clamp_modifiers:
WeightClampModifier:
max_weight_magnitude: 0.5
targets:
- Linear
ignore:
- lm_head
```

## Tips

- **Only override what you need.** The default implementations of `on_start`, `on_update`, `on_end`, and `on_finalize` are no-ops or return `True` — you do not need to call `super()` for these.
- **Use `match_named_modules`** (from `compressed_tensors.utils`) to filter modules by type name or path pattern, consistent with how other modifiers handle `targets` and `ignore`.
- **Keep `on_initialize` lightweight.** Expensive operations (e.g., full-model passes) should be deferred to `on_start` or `on_finalize`.
Loading
Loading