
Commit a6567d7

shanjiaz, brian-dellabetta, and dsikka authored
[Observer] Optimize mse observer (#1450)
SUMMARY: The `calculate_mse_min_max` function previously performed a full grid search across a 0.8 × 100 = 80-point space. After discussing with Alex and Eldar last week, we reduced `maxshrink` to 0.2 to improve performance without sacrificing accuracy. We also implemented an early stopping mechanism: the function now tracks the best quantization error seen so far and stops if no improvement is observed over 5 consecutive steps (patience = 5). The `maxshrink` variable is now configurable in the recipe file, and `patience` (for early stopping) can be passed in as well.

TEST PLAN: All lm_eval tests were run. No regressions in accuracy were observed. Performance improved significantly after the `maxshrink` update. **There's a 3–7 minute slowdown per test when switching from the MinMax to the MSE observer.**

USAGE: Tested the recipe by adding:

```yaml
observer: "mse"
observer_kwargs:
  maxshrink: 0.3
```

More details can be found in this [notion page](https://www.notion.so/Accuracy-test-1d930c7e73f3803bb057fd17d6d45302?pvs=4). Raw timing data are stored [here](https://drive.google.com/drive/folders/1I69QNGKxLJJZ06k9jSw0f0BRhPchV_nt?usp=drive_link).

---------

Signed-off-by: shanjiaz <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
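The shrink-grid search with patience-based early stopping described in the summary can be sketched in isolation as follows. This is a minimal pure-Python illustration, not the library code: the function name, the `n_bits` parameter, and the scalar per-tensor min/max handling are simplifications of the real tensor-based, per-channel implementation.

```python
def mse_min_max(values, grid=100.0, maxshrink=0.20, patience=5, norm=2.4, n_bits=4):
    """Shrink the observed min/max by a factor p and keep the pair with the
    lowest fake-quantization error; stop after `patience` non-improving steps."""
    abs_min, abs_max = min(values), max(values)
    best_err = float("inf")
    best_min, best_max = abs_min, abs_max
    no_improve = 0
    qmax = 2 ** n_bits - 1
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid  # shrink factor: 1.0, 0.99, ... down to 1 - maxshrink
        s_min, s_max = p * abs_min, p * abs_max
        scale = (s_max - s_min) / qmax
        # Fake-quantize each value: round to the grid, clamp, then dequantize,
        # and accumulate the |error|**norm loss.
        err = 0.0
        for v in values:
            q = min(max(round((v - s_min) / scale), 0), qmax)
            err += abs(v - (q * scale + s_min)) ** norm
        if err < best_err:
            best_err, best_min, best_max = err, s_min, s_max
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:  # early stop, as in this patch
                break
    return best_min, best_max
```

With the default `maxshrink` of 0.20 the loop runs at most `int(0.20 * 100) = 20` steps instead of the previous 80, and the patience counter can end it even earlier.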
1 parent 1fb1377 commit a6567d7

File tree

2 files changed (+15, −3 lines)

2 files changed

+15
-3
lines changed

src/llmcompressor/observers/mse.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -22,17 +22,19 @@ def __init__(
         quantization_args: QuantizationArgs,
         averaging_constant: float = 0.01,
         grid: float = 100.0,
-        maxshrink: float = 0.80,
         norm: float = 2.4,
         global_scale: Optional[torch.Tensor] = None,
     ):
         super().__init__(quantization_args=quantization_args, global_scale=global_scale)
 
+        kwargs = quantization_args.observer_kwargs or {}
+        self.maxshrink = kwargs.get("maxshrink", 0.20)
+        self.patience = kwargs.get("patience", 5)
+
         self.min_val = {}
         self.max_val = {}
         self.averaging_constant = averaging_constant
         self.grid = grid
-        self.maxshrink = maxshrink
         self.norm = norm
 
     def calculate_mse_min_max(
@@ -62,6 +64,10 @@ def calculate_mse_min_max(
         )
         min_val = torch.ones_like(absolute_min_val)
         max_val = torch.zeros_like(absolute_max_val)
+
+        # Early stopping params
+        no_improve_count = 0
+
         for i in range(int(self.maxshrink * self.grid)):
             p = 1 - i / self.grid
             shrinked_min_val = p * absolute_min_val
@@ -94,6 +100,12 @@ def calculate_mse_min_max(
                 best[tmp] = err[tmp]
                 min_val[tmp] = shrinked_min_val[tmp]
                 max_val[tmp] = shrinked_max_val[tmp]
+                no_improve_count = 0
+            else:
+                no_improve_count += 1
+                if no_improve_count >= self.patience:
+                    break
+
         return min_val, max_val
 
     def calculate_qparams(
```
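The constructor change above reads its tuning knobs from `observer_kwargs` with fallbacks to the new defaults. That pattern can be shown standalone (the class name here is hypothetical, not the library's; the keys and defaults match the diff):

```python
class ObserverSettings:
    """Pull optional tuning parameters from an observer_kwargs dict,
    falling back to the defaults introduced in this commit."""

    def __init__(self, observer_kwargs=None):
        kwargs = observer_kwargs or {}
        self.maxshrink = kwargs.get("maxshrink", 0.20)  # previously a hard-coded 0.80 parameter
        self.patience = kwargs.get("patience", 5)       # early-stopping threshold
```

A recipe's `observer_kwargs: {maxshrink: 0.3}` would then surface as `maxshrink == 0.3` while `patience` keeps its default.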

tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_group.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,4 +11,4 @@ quant_stage:
     strategy: "group"
     group_size: 128
     actorder: "group"
-    targets: ["Linear"]
+    targets: ["Linear"]
```

0 commit comments