Skip to content

Commit 27a122f

Browse files
committed
fix mse, add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0b79d09 commit 27a122f

File tree

3 files changed

+68
-21
lines changed

3 files changed

+68
-21
lines changed

src/llmcompressor/observers/min_max.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso
4444
if self.min_vals is not None and self.averaging_constant != 1.0:
4545
# FUTURE: consider scaling by num observations (first dim)
4646
# rather than reducing by first dim
47-
min_vals = self._lerp(min_vals, self.min_vals, self.averaging_constant)
48-
max_vals = self._lerp(max_vals, self.max_vals, self.averaging_constant)
47+
min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant)
48+
max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant)
4949

5050
return min_vals, max_vals
5151

5252
def _lerp(
5353
self, input: torch.Tensor, end: torch.Tensor, weight: float
5454
) -> torch.Tensor:
5555
"""torch lerp_kernel is not implemented for all data types"""
56-
return (input * weight) + (end * (1.0 - weight))
56+
return (input * (1.0 - weight)) + (end * weight)

src/llmcompressor/observers/mse.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from typing import Optional, Tuple
22

33
import torch
4-
from compressed_tensors.quantization.quant_args import QuantizationArgs
4+
from compressed_tensors.quantization.quant_args import (
5+
QuantizationArgs,
6+
QuantizationStrategy,
7+
)
58
from compressed_tensors.quantization.utils import calculate_qparams
9+
from compressed_tensors.utils import patch_attr
610

711
from llmcompressor.observers.base import Observer
812

@@ -42,6 +46,24 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso
4246
(num_observations, *qparam_shape, group_size)
4347
:return: minimum value and maximum value whose shapes are (*qparam_shape, )
4448
"""
49+
min_vals, max_vals = self._mse_min_max(observed)
50+
51+
if self.min_vals is not None and self.averaging_constant != 1.0:
52+
# FUTURE: consider scaling by num observations (first dim)
53+
# rather than reducing by first dim
54+
min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant)
55+
max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant)
56+
57+
return min_vals, max_vals
58+
59+
def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
60+
"""
61+
Grid search for MSE-optimal min and max values
62+
63+
:param observed: value being observed whose shape is
64+
(num_observations, *qparam_shape, group_size)
65+
:return: minimum and maximum values which minimize reconstruction error
66+
"""
4567
from compressed_tensors.quantization.lifecycle import fake_quantize
4668

4769
absolute_min_val = torch.amin(observed, dim=(0, -1))
@@ -67,13 +89,17 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso
6789
quantization_args=self.args,
6890
global_scale=global_scale,
6991
)
70-
q = fake_quantize(
71-
observed,
72-
candidate_scales,
73-
candidate_zero_points,
74-
self.args,
75-
global_scale=global_scale,
76-
)
92+
93+
# Note that observed.shape = (num_observations, *qparams_shape, group_size).
94+
# For the purposes of fake quantization, this is equivalent to token quant
95+
with patch_attr(self.args, "strategy", QuantizationStrategy.TOKEN):
96+
q = fake_quantize(
97+
observed,
98+
candidate_scales.unsqueeze(-1),
99+
candidate_zero_points.unsqueeze(-1),
100+
self.args,
101+
global_scale=global_scale,
102+
)
77103

78104
q -= observed
79105
q.abs_()
@@ -92,3 +118,9 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso
92118
break
93119

94120
return min_val, max_val
121+
122+
def _lerp(
123+
self, input: torch.Tensor, end: torch.Tensor, weight: float
124+
) -> torch.Tensor:
125+
"""torch lerp_kernel is not implemented for all data types"""
126+
return (input * (1.0 - weight)) + (end * weight)

tests/llmcompressor/observers/test_mse.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,45 @@
1515

1616
import pytest
1717
import torch
18+
from compressed_tensors.quantization import fake_quantize
1819
from compressed_tensors.quantization.quant_args import QuantizationArgs
1920

2021
from llmcompressor.observers import MovingAverageMSEObserver, Observer
2122

2223

2324
@pytest.mark.parametrize(
24-
"symmetric,expected_scale,expected_zero_point",
25+
"strategy,symmetric,exp_loss",
2526
[
26-
(True, 0.0078, 0),
27-
(False, 0.0039, -128),
27+
("tensor", True, 4.8103e-06),
28+
("tensor", False, 1.1258e-06),
29+
("channel", True, 2.5675e-06),
30+
("channel", False, 2.3696e-07),
31+
("group", True, 3.1282e-06),
32+
("group", False, 1.3794e-07),
33+
("block", True, 2.8968e-06),
34+
("block", False, 5.6068e-07),
2835
],
2936
)
30-
def test_mse_observer(symmetric, expected_scale, expected_zero_point):
31-
tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
37+
def test_mse_observer(strategy, symmetric, exp_loss):
38+
tensor = torch.arange(24).reshape((6, 4)) / 24
3239
num_bits = 8
33-
weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse")
40+
weights = QuantizationArgs(
41+
num_bits=num_bits,
42+
strategy=strategy,
43+
symmetric=symmetric,
44+
group_size=(2 if strategy == "group" else None),
45+
block_structure=([3, 2] if strategy == "block" else None),
46+
observer="mse",
47+
)
3448

3549
observer = weights.observer
3650
observer = Observer.load_from_registry(observer, base_name="weight", args=weights)
37-
scale, zero_point = observer(tensor)
38-
3951
assert isinstance(observer, MovingAverageMSEObserver)
40-
assert round(scale.item(), 4) == expected_scale
41-
assert round(zero_point.item(), 4) == expected_zero_point
52+
53+
scale, zero_point = observer(tensor)
54+
q_tensor = fake_quantize(tensor, scale, zero_point, weights)
55+
mse_loss = torch.sum((tensor - q_tensor).abs_().pow_(2)) / tensor.numel()
56+
assert mse_loss == pytest.approx(exp_loss, abs=1e-10)
4257

4358

4459
def test_mse_observer_symmetric_scale_range():

0 commit comments

Comments
 (0)