-
Notifications
You must be signed in to change notification settings - Fork 453
[Docs] Add Developer Guides section #2517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dsikka
wants to merge
7
commits into
main
Choose a base branch
from
dev-tutorials
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+737
−31
Open
Changes from 3 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
c335602
[Docs] Add Developer Tutorials section
dsikka 6c7c46f
[Docs] Rename Developer Tutorials to Developer Guides
dsikka 88a12b9
[Docs] Rename MoE tutorial to "Adding MoE Calibration Support for a N…
dsikka ac3eba7
[Docs] Update MoE calibration guide to reference Llama4
dsikka a5c6428
[Docs] Correct observer tutorial: observers compute min/max, not scal…
dsikka 6e5a80d
[Docs] Update observers guide with correct descriptions and new obser…
dsikka bb745e6
Merge branch 'main' into dev-tutorials
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,177 @@ | ||||||
| # Adding a New Modifier | ||||||
|
|
||||||
| Modifiers are the core extension point in LLM Compressor. Each compression algorithm — GPTQ, AWQ, SmoothQuant, and others — is implemented as a modifier. This tutorial walks through the modifier contract, lifecycle, and how to implement a custom one. | ||||||
|
|
||||||
| ## What is a Modifier? | ||||||
|
|
||||||
| A modifier is a Pydantic model that hooks into the compression pipeline at well-defined lifecycle points. When you call `oneshot`, LLM Compressor: | ||||||
|
|
||||||
| 1. Instantiates modifiers from the recipe | ||||||
| 2. Calls `initialize` on each modifier | ||||||
| 3. Runs calibration batches, firing `Event`s that modifiers respond to | ||||||
| 4. Calls `finalize` on each modifier | ||||||
|
|
||||||
| Modifiers express what they want to do at each stage by overriding lifecycle hooks. | ||||||
|
|
||||||
| ## The Modifier Contract | ||||||
|
|
||||||
| All modifiers subclass `llmcompressor.modifiers.Modifier` and must implement `on_initialize`. All other hooks are optional. | ||||||
|
|
||||||
| ```python | ||||||
| from llmcompressor.modifiers import Modifier | ||||||
| from llmcompressor.core import State, Event | ||||||
|
|
||||||
| class MyModifier(Modifier): | ||||||
| # Pydantic fields — declare your parameters here | ||||||
| my_param: float = 1.0 | ||||||
|
|
||||||
| def on_initialize(self, state: State, **kwargs) -> bool: | ||||||
| # Called once before calibration begins. | ||||||
| # Set up hooks, attach attributes to modules, etc. | ||||||
| # Return True if initialization succeeded. | ||||||
| ... | ||||||
| return True | ||||||
|
|
||||||
| def on_start(self, state: State, event: Event, **kwargs): | ||||||
| # Called when calibration starts (first BATCH_START event). | ||||||
| ... | ||||||
|
|
||||||
| def on_update(self, state: State, event: Event, **kwargs): | ||||||
| # Called on every event while the modifier is active. | ||||||
| ... | ||||||
|
|
||||||
| def on_end(self, state: State, event: Event, **kwargs): | ||||||
| # Called when calibration ends. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
|
||||||
| ... | ||||||
|
|
||||||
| def on_finalize(self, state: State, **kwargs) -> bool: | ||||||
| # Called after calibration completes. | ||||||
| # Clean up hooks, apply final transformations, etc. | ||||||
| # Return True if finalization succeeded. | ||||||
| ... | ||||||
| return True | ||||||
| ``` | ||||||
|
|
||||||
| ### Lifecycle Summary | ||||||
|
|
||||||
| | Hook | When it runs | Required | | ||||||
| |------|-------------|----------| | ||||||
| | `on_initialize` | Once, before calibration | Yes | | ||||||
| | `on_start` | First `BATCH_START` event | No | | ||||||
| | `on_update` | Every event while active | No | | ||||||
| | `on_end` | `BATCH_END` when modifier ends | No | | ||||||
| | `on_finalize` | Once, after calibration | No | | ||||||
|
|
||||||
| ### The `State` Object | ||||||
|
|
||||||
| `state.model` gives you the `torch.nn.Module` being compressed. This is the primary object you will interact with in most hooks. | ||||||
|
|
||||||
| ### Pydantic Parameters | ||||||
|
|
||||||
| Because `Modifier` is a Pydantic model with `extra="forbid"`, all parameters must be declared as class-level fields. This also means your modifier can be instantiated directly in Python or from a YAML recipe. | ||||||
|
|
||||||
| ```python | ||||||
| class MyModifier(Modifier): | ||||||
| targets: list[str] = ["Linear"] | ||||||
| scale_factor: float = 0.5 | ||||||
| ignore: list[str] = [] | ||||||
| ``` | ||||||
|
|
||||||
| ## Attaching Hooks with `HooksMixin` | ||||||
|
|
||||||
| `Modifier` inherits from `HooksMixin`, which provides a managed way to register PyTorch forward hooks. Hooks registered through `HooksMixin` are automatically removed when `finalize` is called. | ||||||
|
|
||||||
| ```python | ||||||
| from llmcompressor.modifiers import Modifier | ||||||
| from llmcompressor.core import State | ||||||
|
|
||||||
| class MyModifier(Modifier): | ||||||
| def on_initialize(self, state: State, **kwargs) -> bool: | ||||||
| for name, module in state.model.named_modules(): | ||||||
| if "Linear" in type(module).__name__: | ||||||
| self.register_hook( | ||||||
| module, | ||||||
| self._forward_hook, | ||||||
| "forward", | ||||||
| ) | ||||||
| return True | ||||||
|
|
||||||
| def _forward_hook(self, module, inputs, output): | ||||||
| # Runs after every forward pass through this module | ||||||
| ... | ||||||
| ``` | ||||||
|
|
||||||
| ## Example: A Weight-Clamping Modifier | ||||||
|
|
||||||
| The following modifier clamps the absolute magnitude of all `Linear` layer weights during `on_finalize`, after calibration is complete. | ||||||
|
|
||||||
| ```python | ||||||
| import torch | ||||||
| from compressed_tensors.utils import match_named_modules | ||||||
| from llmcompressor.modifiers import Modifier | ||||||
| from llmcompressor.core import State | ||||||
|
|
||||||
| class WeightClampModifier(Modifier): | ||||||
| """ | ||||||
| Clamps the magnitude of Linear layer weights to a maximum absolute value. | ||||||
|
|
||||||
| :param max_weight_magnitude: maximum allowed absolute weight value | ||||||
| :param targets: module types to target | ||||||
| :param ignore: module names to skip | ||||||
| """ | ||||||
|
|
||||||
| max_weight_magnitude: float = 1.0 | ||||||
| targets: list[str] = ["Linear"] | ||||||
| ignore: list[str] = [] | ||||||
|
|
||||||
| def on_initialize(self, state: State, **kwargs) -> bool: | ||||||
| return True | ||||||
|
|
||||||
| def on_finalize(self, state: State, **kwargs) -> bool: | ||||||
| for name, module in match_named_modules( | ||||||
| state.model, self.targets, self.ignore | ||||||
| ): | ||||||
| with torch.no_grad(): | ||||||
| module.weight.clamp_( | ||||||
| -self.max_weight_magnitude, | ||||||
| self.max_weight_magnitude, | ||||||
| ) | ||||||
| return True | ||||||
| ``` | ||||||
|
|
||||||
| ### Using the Modifier with `oneshot` | ||||||
|
|
||||||
| ```python | ||||||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||
| from llmcompressor import oneshot | ||||||
|
|
||||||
| model = AutoModelForCausalLM.from_pretrained("your-model") | ||||||
| tokenizer = AutoTokenizer.from_pretrained("your-model") | ||||||
|
|
||||||
| oneshot( | ||||||
| model=model, | ||||||
| recipe=[WeightClampModifier(max_weight_magnitude=0.5, ignore=["lm_head"])], | ||||||
| ) | ||||||
|
|
||||||
| model.save_pretrained("your-model-clamped", save_compressed=True) | ||||||
| tokenizer.save_pretrained("your-model-clamped") | ||||||
| ``` | ||||||
|
|
||||||
| ### Using the Modifier from a YAML Recipe | ||||||
|
|
||||||
| ```yaml | ||||||
| weight_clamp_stage: | ||||||
| weight_clamp_modifiers: | ||||||
| WeightClampModifier: | ||||||
| max_weight_magnitude: 0.5 | ||||||
| targets: | ||||||
| - Linear | ||||||
| ignore: | ||||||
| - lm_head | ||||||
| ``` | ||||||
|
|
||||||
| ## Tips | ||||||
|
|
||||||
| - **Only override what you need.** The default implementations of `on_start`, `on_update`, `on_end`, and `on_finalize` are no-ops or return `True` — you do not need to call `super()` for these. | ||||||
| - **Use `match_named_modules`** (from `compressed_tensors.utils`) to filter modules by type name or path pattern, consistent with how other modifiers handle `targets` and `ignore`. | ||||||
| - **Keep `on_initialize` lightweight.** Expensive operations (e.g., full-model passes) should be deferred to `on_start` or `on_finalize`. | ||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment here is slightly misleading.
on_startis not necessarily called on the firstBATCH_STARTevent of calibration, but rather on the firstBATCH_STARTevent for which the modifier becomes active (i.e., whenevent.current_index >= self.start). For better clarity, I suggest changing the comment.