Skip to content

Commit 9d33973

Browse files
committed
perf: make MSE observer compatible with torch.compile
compile inner _compute_candidate_error via torch.compile(dynamic=True). early stopping preserved in outer loop. compile flag added as oneshot arg. requires: vllm-project/compressed-tensors#627 related: pytorch/pytorch#177131
1 parent 36c30ee commit 9d33973

File tree

4 files changed

+206
-70
lines changed

4 files changed

+206
-70
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from llmcompressor.core.session_functions import active_session
2323
from llmcompressor.datasets import get_calibration_dataloader
2424
from llmcompressor.entrypoints.utils import post_process, pre_process
25+
from llmcompressor.observers.compile_config import set_observer_compile
2526
from llmcompressor.modeling.moe_context import moe_calibration_context
2627
from llmcompressor.pipelines import CalibrationPipeline
2728

@@ -299,6 +300,8 @@ def oneshot(
299300
sequential_targets: list[str] | None = None,
300301
sequential_offload_device: str = "cpu",
301302
quantization_aware_calibration: bool = True,
303+
sequential_prefetch: bool = False,
304+
enable_observer_compile: bool = False,
302305
# Miscellaneous arguments
303306
output_dir: str | None = None,
304307
log_dir: str | None = None,
@@ -364,9 +367,10 @@ def oneshot(
364367
:param streaming: True to stream data from a cloud dataset.
365368
:param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
366369
:param preprocessing_num_workers: Number of processes for dataset preprocessing.
367-
:param dataloader_num_workers: Number of worker processes for data loading. Set to 0
368-
to disable multiprocessing. Note: Custom data collators may not work with
369-
multiprocessing. Default is 0.
370+
:param dataloader_num_workers: Number of worker processes for data loading. Default
371+
is 0 (safe for low CPU/GPU memory). Set to 2 or more for faster calibration if
372+
you have sufficient RAM. Custom data collators may not work with
373+
multiprocessing.
370374
:param min_tokens_per_module: Minimum percentage of tokens per
371375
module, relevant for MoE models.
372376
:param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
@@ -388,6 +392,9 @@ def oneshot(
388392
calibration in the sequential pipeline. When True, quantization is applied
389393
during forward pass in calibration. When False, quantization is disabled
390394
during forward pass in calibration. Default is set to True.
395+
:param sequential_prefetch: When using the sequential pipeline, prefetch the
396+
next batch in a background thread to overlap onload with forward. Default
397+
False; set True for faster calibration when GPU memory allows.
391398
392399
# Miscellaneous arguments
393400
:param output_dir: Path to save the output model after calibration.
@@ -400,9 +407,10 @@ def oneshot(
400407

401408
# pass all args directly into Oneshot
402409
local_args = {
403-
k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
410+
k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "enable_observer_compile")
404411
}
405412
one_shot = Oneshot(**local_args, **kwargs)
413+
set_observer_compile(enable_observer_compile)
406414
one_shot()
407415

408416
return one_shot.model
src/llmcompressor/observers/compile_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
"""
Global configuration for observer torch.compile support.

The compile flag is set by the oneshot entrypoint and read by observer
instances at call time. This avoids threading the flag through recipe
and modifier layers.
"""

# Module-level flag; defaults to eager (compile disabled).
_enable_observer_compile: bool = False


def set_observer_compile(enabled: bool) -> None:
    """Enable or disable torch.compile acceleration for observers.

    Called by the oneshot entrypoint before calibration begins.

    :param enabled: True to route observer inner computations through
        their torch.compiled wrappers, False for eager execution
    """
    global _enable_observer_compile
    _enable_observer_compile = enabled


def get_observer_compile() -> bool:
    """Return whether observer torch.compile acceleration is enabled."""
    return _enable_observer_compile

src/llmcompressor/observers/mse.py

Lines changed: 143 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
4+
import torch._dynamo.config
45
from compressed_tensors.quantization import (
56
QuantizationArgs,
67
QuantizationStrategy,
78
)
89
from compressed_tensors.quantization.lifecycle import fake_quantize
910
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
10-
from compressed_tensors.utils import patch_attr
1111

1212
from llmcompressor.observers.base import MinMaxTuple, Observer
13+
from llmcompressor.observers.compile_config import get_observer_compile
1314
from llmcompressor.observers.moving_base import MovingAverageObserverBase
1415

1516
__all__ = ["MovingAverageMSEObserver"]
1617

18+
# Allow torch.compile to handle scalar conversions inside
19+
# compressed_tensors' calculate_qparams (float(bit_range)).
20+
# Same approach as GPTQ compile path (commit a4f9ba2e).
21+
torch._dynamo.config.capture_scalar_outputs = True
22+
1723

1824
@Observer.register("memoryless_mse")
1925
class MemorylessMSEObserver(Observer):
@@ -32,7 +38,7 @@ class MemorylessMSEObserver(Observer):
3238
:param module: optional module with attached quantization parameters. This argument
3339
is required to utilize existing qparams such as global_scale or g_idx
3440
:param **observer_kwargs: keyword arguments for observer initialization\n
35-
maxshrink: maximum shrink amount (in grid steps). The number of
41+
maxshrink: maximum shrink amount (in "grid steps"). The number of
3642
search steps is int(maxshrink * grid)\n
3743
patience: number of consecutive search steps without improvement before
3844
early stopping\n
@@ -53,32 +59,39 @@ def __init__(self, *args, **kwargs):
5359
self.grid = observer_kwargs.get("grid", 100.0)
5460
self.norm = observer_kwargs.get("norm", 2.4)
5561

56-
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
57-
# min[min_vals, max_vals](mse_quant_error)
58-
global_scale = self._get_module_param("global_scale")
62+
# Pre-create token_args to avoid patch_attr context manager
63+
# which causes torch.compile graph breaks
64+
self._token_args = self.args.model_copy(
65+
update={"strategy": QuantizationStrategy.TOKEN}
66+
)
67+
68+
def _call_grid_search(
69+
self,
70+
observed: torch.Tensor,
71+
global_scale: Optional[torch.Tensor],
72+
optimize_global_scale: bool,
73+
) -> MinMaxTuple:
5974
return _grid_search_mse(
6075
observed,
6176
self.args,
77+
self._token_args,
6278
self.maxshrink,
6379
self.patience,
6480
self.grid,
6581
self.norm,
6682
global_scale=global_scale,
67-
optimize_global_scale=False,
83+
optimize_global_scale=optimize_global_scale,
84+
enable_compile=get_observer_compile(),
6885
)
6986

87+
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
88+
# min[min_vals, max_vals](mse_quant_error)
89+
global_scale = self._get_module_param("global_scale")
90+
return self._call_grid_search(observed, global_scale, False)
91+
7092
def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
7193
# min[min_vals, max_vals, global_scale](mse_quant_error)
72-
return _grid_search_mse(
73-
observed,
74-
self.args,
75-
self.maxshrink,
76-
self.patience,
77-
self.grid,
78-
self.norm,
79-
global_scale=None,
80-
optimize_global_scale=True,
81-
)
94+
return self._call_grid_search(observed, None, True)
8295

8396

8497
@Observer.register("mse")
@@ -98,7 +111,7 @@ class MovingAverageMSEObserver(MovingAverageObserverBase):
98111
:param module: optional module with attached quantization parameters. This argument
99112
is required to utilize existing qparams such as global_scale or g_idx
100113
:param **observer_kwargs: keyword arguments for observer initialization\n
101-
maxshrink: maximum shrink amount (in grid steps). The number of
114+
maxshrink: maximum shrink amount (in "grid steps"). The number of
102115
search steps is int(maxshrink * grid)\n
103116
patience: number of consecutive search steps without improvement before
104117
early stopping\n
@@ -119,55 +132,134 @@ def __init__(self, *args, **kwargs):
119132
self.grid = observer_kwargs.get("grid", 100.0)
120133
self.norm = observer_kwargs.get("norm", 2.4)
121134

122-
def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
123-
# min[min_vals, max_vals](mse_quant_error)
124-
global_scale = self._get_module_param("global_scale")
135+
# Pre-create token_args to avoid patch_attr context manager
136+
# which causes torch.compile graph breaks
137+
self._token_args = self.args.model_copy(
138+
update={"strategy": QuantizationStrategy.TOKEN}
139+
)
140+
141+
def _call_grid_search(
142+
self,
143+
observed: torch.Tensor,
144+
global_scale: Optional[torch.Tensor],
145+
optimize_global_scale: bool,
146+
) -> MinMaxTuple:
125147
return _grid_search_mse(
126148
observed,
127149
self.args,
150+
self._token_args,
128151
self.maxshrink,
129152
self.patience,
130153
self.grid,
131154
self.norm,
132155
global_scale=global_scale,
133-
optimize_global_scale=False,
156+
optimize_global_scale=optimize_global_scale,
157+
enable_compile=get_observer_compile(),
134158
)
135159

160+
def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
161+
# min[min_vals, max_vals](mse_quant_error)
162+
global_scale = self._get_module_param("global_scale")
163+
return self._call_grid_search(observed, global_scale, False)
164+
136165
def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
137166
# min[min_vals, max_vals, global_scale](mse_quant_error)
138-
return _grid_search_mse(
139-
observed,
140-
self.args,
141-
self.maxshrink,
142-
self.patience,
143-
self.grid,
144-
self.norm,
145-
global_scale=None,
146-
optimize_global_scale=True,
147-
)
167+
return self._call_grid_search(observed, None, True)
168+
169+
170+
def _compute_candidate_error(
    observed: torch.Tensor,
    args: QuantizationArgs,
    token_args: QuantizationArgs,
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    p: float,
    norm: float,
    global_scale: Optional[torch.Tensor],
    optimize_global_scale: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate the quantization error for one candidate shrink factor.

    Shared helper for the MSE grid search; when compile is enabled via
    oneshot, callers invoke this through its torch.compiled wrapper.

    :param observed: value of shape (num_observations, *qparams_shape, group_size)
    :param args: quantization args used for computing qparams
    :param token_args: quantization args with strategy set to TOKEN, pre-created
        to avoid patch_attr context manager which causes torch.compile graph breaks
    :param min_val: per-channel minimum values
    :param max_val: per-channel maximum values
    :param p: shrink factor (1 - i/grid)
    :param norm: exponent used when computing the error
    :param global_scale: precomputed global scale to use for quantization
    :param optimize_global_scale: If True, recompute global_scale from candidates
    :return: (error, shrinked_min_val, shrinked_max_val)
    """
    # Shrink the observed range by the candidate factor
    cand_min = p * min_val
    cand_max = p * max_val

    # Optionally re-derive the global scale from the shrunk candidates
    if optimize_global_scale:
        global_scale = generate_gparam(cand_min, cand_max)

    scales, zero_points = calculate_qparams(
        min_vals=cand_min,
        max_vals=cand_max,
        quantization_args=args,
        global_scale=global_scale,
    )

    # Pre-created token_args replaces the patch_attr context manager
    # so torch.compile sees no graph break here
    q = fake_quantize(
        observed,
        scales.unsqueeze(-1),
        zero_points.unsqueeze(-1),
        token_args,
        global_scale=global_scale,
    ).to(observed.dtype)

    # Per-channel p-norm error, reduced over observations and group dims
    delta = q - observed
    err = delta.abs_().pow_(norm).sum(dim=(0, -1))
    return err, cand_min, cand_max
225+
226+
227+
# torch.compiled version of the inner error computation. Only this helper
# is compiled; the surrounding grid-search loop remains eager so that
# early stopping (data-dependent control flow) keeps working.
_compute_candidate_error_compiled = torch.compile(_compute_candidate_error, dynamic=True)
148233

149234

150235
def _grid_search_mse(
151236
observed: torch.Tensor,
152237
args: QuantizationArgs,
238+
token_args: QuantizationArgs,
153239
maxshrink: float,
154240
patience: float,
155241
grid: float,
156242
norm: float,
157243
global_scale: Optional[torch.Tensor] = None,
158244
optimize_global_scale: bool = False,
245+
enable_compile: bool = False,
159246
) -> MinMaxTuple:
160247
"""
161248
Perform a 1-D grid search to find per-channel min/max ranges that minimize
162249
mean-squared quantization error.
163250
164-
This routine progressively “shrinks” the absolute min/max ranges of the
165-
observed tensor and evaluates the quantization error at each candidate
166-
range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``.
251+
Progressively shrinks the absolute min/max ranges of the observed tensor
252+
and evaluates the quantization error at each candidate. Early stopping
253+
exits when no improvement is found for ``patience`` consecutive steps.
254+
255+
When enable_compile is True, the inner error computation is executed
256+
through a torch.compiled wrapper for accelerated execution while
257+
preserving early stopping in the outer loop.
167258
168259
:param observed: value of shape (num_observations, *qparams_shape, group_size)
169260
:param args: quantization args used for computing qparams and fake quant
170-
:param maxshrink: maximum shrink amount (in “grid steps”). The number of
261+
:param token_args: quantization args with strategy set to TOKEN
262+
:param maxshrink: maximum shrink amount (in "grid steps"). The number of
171263
search steps is int(maxshrink * grid)
172264
:param patience: number of consecutive search steps without improvement before
173265
early stopping
@@ -178,50 +270,35 @@ def _grid_search_mse(
178270
`optimize_global_scale` is True
179271
:param optimize_global_scale: If True, recompute ``global_scale`` from the
180272
candidate min/max during each step of the search
273+
:param enable_compile: If True, use torch.compiled inner computation
181274
"""
182275
min_val = torch.amin(observed, dim=(0, -1))
183276
max_val = torch.amax(observed, dim=(0, -1))
184277
best_error = torch.full_like(min_val, torch.finfo(min_val.dtype).max)
185278
best_min_val = min_val.clone()
186279
best_max_val = max_val.clone()
187280

188-
# Early stopping params
281+
compute_fn = (
282+
_compute_candidate_error_compiled if enable_compile
283+
else _compute_candidate_error
284+
)
189285
no_improve_count = 0
190286

191287
# @ksayers @HGCharles: investigate searching over separate shrinking factors
192288
for i in range(int(maxshrink * grid)):
193289
p = 1 - i / grid
194-
shrinked_min_val = p * min_val
195-
shrinked_max_val = p * max_val
196-
197-
if optimize_global_scale:
198-
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
199-
200-
candidate_scales, candidate_zero_points = calculate_qparams(
201-
min_vals=shrinked_min_val,
202-
max_vals=shrinked_max_val,
203-
quantization_args=args,
204-
global_scale=global_scale,
290+
err, shrinked_min_val, shrinked_max_val = compute_fn(
291+
observed,
292+
args,
293+
token_args,
294+
min_val,
295+
max_val,
296+
p,
297+
norm,
298+
global_scale,
299+
optimize_global_scale,
205300
)
206301

207-
# Note that observed.shape = (num_observations, *qparams_shape, group_size).
208-
# For the purposes of fake quantization, this is equivalent to token quant
209-
with patch_attr(args, "strategy", QuantizationStrategy.TOKEN):
210-
q = fake_quantize(
211-
observed,
212-
candidate_scales.unsqueeze(-1),
213-
candidate_zero_points.unsqueeze(-1),
214-
args,
215-
global_scale=global_scale,
216-
).to(observed.dtype)
217-
# Note that due to forward quantization implementation, token quant,
218-
# unlike tensor_group, requires extra dtype cast
219-
220-
q -= observed
221-
q.abs_()
222-
q.pow_(norm)
223-
err = torch.sum(q, dim=(0, -1))
224-
225302
tmp = err < best_error
226303
if torch.any(tmp):
227304
best_error[tmp] = err[tmp]

0 commit comments

Comments
 (0)