
Commit 7a3d43a

add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 37b846a commit 7a3d43a

File tree

2 files changed (+343, -8 lines)


src/llmcompressor/observers/min_max.py

Lines changed: 20 additions & 7 deletions
@@ -1,9 +1,9 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
-from compressed_tensors.utils import deprecated
+from compressed_tensors.utils import deprecated, patch_attr
 
 from llmcompressor.observers.base import Observer
 
@@ -58,6 +58,8 @@ def calculate_updated_min_max(
 
         # early stopping, save some computation and memory
         if self.averaging_constant == 1.0:
+            self.min_val[tensor_id] = min_val
+            self.max_val[tensor_id] = max_val
             return min_val, max_val
 
         running_min_val = self.min_val.get(tensor_id, None)
@@ -86,9 +88,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
         :return: updated global scale derived from the observed tensor
         """
 
-        updated_min_val, updated_max_val = self.calculate_updated_min_max(
-            observed=observed
-        )
+        # patch to avoid affecting running means
+        with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}):
+            updated_min_val, updated_max_val = self.calculate_updated_min_max(
+                observed=observed
+            )
         return generate_gparam(
             updated_min_val=updated_min_val, updated_max_val=updated_max_val
         )
@@ -126,14 +130,23 @@ def calculate_qparams(
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,
-        dim: int,
+        dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
        Calculate quantization parameters along the specified dimension
        """
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        # cast to a set
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce all dimensions except the ones passed as arguments to this function
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
             reduce_dims=reduce_dims,
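
A note on the `calculate_gparam` hunk above: `patch_attr` is used as a context manager that temporarily swaps an attribute and restores it on exit, so the `min_val`/`max_val` writes introduced by the early-stopping branch land in throwaway dicts rather than in the observer's running state. A minimal sketch of the idea (the name `patch_attr_sketch` is hypothetical; this is not the compressed-tensors implementation):

    from contextlib import contextmanager

    @contextmanager
    def patch_attr_sketch(obj, name, value):
        # temporarily replace obj.<name>, restoring the original on exit
        original = getattr(obj, name)
        setattr(obj, name, value)
        try:
            yield
        finally:
            setattr(obj, name, original)

With this shape, `calculate_gparam` can reuse `calculate_updated_min_max` for a one-off global scale without polluting the running means.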

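Similarly, the reworked `get_qparams_along_dim` keeps every axis listed in `dim` and reduces over all the others. A standalone illustration of the same normalization (the helper `reduce_dims_for` and the tensor shape are hypothetical, for demonstration only):

    import torch

    def reduce_dims_for(observed: torch.Tensor, dim) -> tuple:
        # normalize dim to a set of non-negative axes, as in the hunk above
        if isinstance(dim, int):
            dim = [dim]
        dim = {d if d >= 0 else observed.ndim + d for d in dim}
        # reduce over every axis not listed
        return tuple(idx for idx in range(observed.ndim) if idx not in dim)

    observed = torch.empty(2, 3, 4)
    print(reduce_dims_for(observed, 0))        # (1, 2)
    print(reduce_dims_for(observed, (0, -1)))  # (1,), since -1 normalizes to axis 2
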
tests/llmcompressor/modifiers/calibration/test_observers.py

Lines changed: 323 additions & 1 deletion
@@ -6,7 +6,12 @@
     initialize_module_for_quantization,
 )
 
-from llmcompressor.modifiers.quantization.calibration import initialize_observer
+from llmcompressor.modifiers.quantization.calibration import (
+    calibrate_input_hook,
+    initialize_observer,
+    update_weight_global_scale,
+    update_weight_zp_scale,
+)
 
 
 @pytest.mark.parametrize(
@@ -59,3 +64,320 @@ def test_observers_update(shape, group_size, actorder):
 def assert_alike(a, b):
     assert a.dtype == b.dtype
     assert a.shape == b.shape
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(23.0)},
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]])},
+            {"default": torch.tensor([[5], [11], [17], [23]])},
+            torch.tensor(
+                [
+                    [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875],
+                    [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]]),
+                1: torch.tensor([[3], [9], [15], [21]]),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]]),
+                1: torch.tensor([[5], [11], [17], [23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",  # requires float4
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]]),
+                1: torch.tensor([[3], [9], [15], [21]]),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]]),
+                1: torch.tensor([[5], [11], [17], [23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375],
+                    [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875],
+                    [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750],
+                    [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000],
+                ],
+            ),
+            1.1,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+                observer="minmax",
+            ),
+            {
+                "block_0_0": torch.tensor([[0]]),
+                "block_0_1": torch.tensor([[3]]),
+                "block_1_0": torch.tensor([[12]]),
+                "block_1_1": torch.tensor([[15]]),
+            },
+            {
+                "block_0_0": torch.tensor([[8]]),
+                "block_0_1": torch.tensor([[11]]),
+                "block_1_0": torch.tensor([[20]]),
+                "block_1_1": torch.tensor([[23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",  # equivalent to tensor
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(23.0)},
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+    ],
+)
+def test_static_weight_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    weight = tensor([[ 0,  1,  2,  3,  4,  5],
+                     [ 6,  7,  8,  9, 10, 11],
+                     [12, 13, 14, 15, 16, 17],
+                     [18, 19, 20, 21, 22, 23]])
+    """
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "weight")
+    update_weight_global_scale(linear)
+    update_weight_zp_scale(linear)
+
+    observer = getattr(linear, "weight_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
+    input = torch.eye(input_size, dtype=torch.bfloat16)
+    output = linear(input)
+
+    print(output.T)
+    print(torch.nn.functional.mse_loss(output.T, linear.weight))
+    assert torch.allclose(output.T, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(5.0)},
+            torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
+            0.06,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",  # equivalent to tensor
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(5.0)},
+            torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
+            0.06,
+        ),
+        # channel is not supported, but is in principle equivalent to token/tensor
+        # (
+        #     QuantizationArgs(
+        #         num_bits=4,
+        #         type="int",
+        #         symmetric=True,
+        #         strategy="group",
+        #         group_size=3,
+        #         observer="minmax",
+        #     ),
+        #     {
+        #         "default": torch.tensor([[0]]),
+        #         1: torch.tensor([[3]]),
+        #     },
+        #     {
+        #         "default": torch.tensor([[2]]),
+        #         1: torch.tensor([[5]]),
+        #     },
+        #     torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]),
+        #     0.04,
+        # ),
+        # (
+        #     QuantizationArgs(
+        #         num_bits=4,
+        #         type="float",  # tensor group requires FP4
+        #         symmetric=True,
+        #         strategy="tensor_group",
+        #         group_size=3,
+        #         observer="minmax",
+        #     ),
+        #     {
+        #         "default": torch.tensor([[0]]),
+        #         1: torch.tensor([[3]]),
+        #     },
+        #     {
+        #         "default": torch.tensor([[2]]),
+        #         1: torch.tensor([[5]]),
+        #     },
+        #     torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]),
+        #     0.1,
+        # ),
+        # block is not supported, but is in principle similar to group
+    ],
+)
+def test_static_activation_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[ 0,  1,  2,  3,  4,  5]])
+    """
+    # set up activation (and identity weight)
+    input_size = 6
+    input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "input")
+    linear.register_forward_pre_hook(calibrate_input_hook)
+
+    # calibration forward pass
+    output = linear(input)
+
+    # check calibration
+    observer = getattr(linear, "input_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # check forward pass
+    print(args.strategy)
+    print(output)
+    print(torch.nn.functional.mse_loss(output, input))
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss
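
For intuition about where the expected tensors in these tests come from: with `strategy="tensor"`, the observer records min 0 / max 23 over the whole weight, and symmetric int4 fake-quantization reproduces the `exp_quant` values. A worked sketch, assuming the common symmetric mapping `scale = 2 * max_abs / (2**num_bits - 1)` with the quantized range clamped to [-8, 7] (an illustration consistent with the numbers above, not necessarily the exact compressed-tensors code path):

    import torch

    weight = torch.arange(24, dtype=torch.bfloat16).reshape(4, 6)
    max_abs = weight.abs().max()                   # 23.0
    scale = 2 * max_abs / 15                       # rounds to 3.0625 in bfloat16
    q = torch.clamp(torch.round(weight / scale), -8, 7)
    dequant = q * scale
    print(dequant)  # e.g. row 0: [0.0, 0.0, 3.0625, 3.0625, 3.0625, 6.125]

The `exp_loss` bounds then cap the mean squared error between the dequantized values and the originals, which is why the coarse tensor/token strategies (0.85) tolerate more error than the finer channel/group strategies (0.45).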
