
Commit b4dfb19

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5ef8bdf commit b4dfb19


3 files changed: +290 -5 lines changed


src/llmcompressor/observers/base.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def get_qparams_along_dim(
         # convert negative dims
         dim = [d if d >= 0 else observed.ndim + d for d in dim]
 
-        # reduce all dimensions except the the one pass as argument to this function
+        # reduce all dimensions except the ones passed as argument to this function
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
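
The context lines above show that the base implementation already accepts a list of dims and converts negative indices; the min_max.py change below brings the subclass override to parity. For reference, a standalone sketch (not part of the commit) of what the conversion and reduction compute for a 3-D observed tensor:

    observed_ndim = 3
    dim = [0, -1]

    # negative indices wrap around, mirroring the comprehension in the diff
    dim = [d if d >= 0 else observed_ndim + d for d in dim]
    print(dim)  # [0, 2]

    # every other dimension is reduced when computing qparams
    reduce_dims = tuple(idx for idx in range(observed_ndim) if idx not in dim)
    print(reduce_dims)  # (1,)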

src/llmcompressor/observers/min_max.py

Lines changed: 12 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -128,14 +128,23 @@ def calculate_qparams(
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,
-        dim: int,
+        dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
         Calculate quantization parameters along the specified dimension(s)
         """
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        # standardize dim as a set of unique dims
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce all dimensions except the ones passed as argument to this function
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
             reduce_dims=reduce_dims,
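
With the widened signature, callers can pass either a single dim or an iterable of dims, including negative indices. A hypothetical usage sketch (observer stands for any minmax observer instance created elsewhere, and the (scale, zero_point) return pair is assumed from the Tuple annotation; none of this is part of the commit):

    import torch

    observed = torch.rand(8, 4, 16)

    # previously: only a single int dim was accepted
    scale, zero_point = observer.get_qparams_along_dim(observed, dim=0)

    # now: an iterable of dims also works; dims 0 and -2 (i.e. 1) are kept,
    # so only dim 2 is reduced when computing qparams
    scale, zero_point = observer.get_qparams_along_dim(observed, dim=(0, -2))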

tests/llmcompressor/modifiers/calibration/test_observers.py

Lines changed: 277 additions & 1 deletion
@@ -4,9 +4,15 @@
     QuantizationArgs,
     QuantizationScheme,
     initialize_module_for_quantization,
+    QuantizationStatus,
 )
 
-from llmcompressor.modifiers.quantization.calibration import initialize_observer
+from llmcompressor.modifiers.quantization.calibration import (
+    initialize_observer,
+    update_weight_zp_scale,
+    update_weight_global_scale,
+    calibrate_input_hook,
+)
 
 
 @pytest.mark.parametrize(
@@ -59,3 +65,273 @@ def test_observers_update(shape, group_size, actorder):
 def assert_alike(a, b):
     assert a.dtype == b.dtype
     assert a.shape == b.shape
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_tol",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
+            {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            5.0,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+                observer="minmax",
+            ),
+            {
+                "block_0_0": torch.tensor([[0]], dtype=torch.bfloat16),
+                "block_0_1": torch.tensor([[3]], dtype=torch.bfloat16),
+                "block_1_0": torch.tensor([[12]], dtype=torch.bfloat16),
+                "block_1_1": torch.tensor([[15]], dtype=torch.bfloat16),
+            },
+            {
+                "block_0_0": torch.tensor([[8]], dtype=torch.bfloat16),
+                "block_0_1": torch.tensor([[11]], dtype=torch.bfloat16),
+                "block_1_0": torch.tensor([[20]], dtype=torch.bfloat16),
+                "block_1_1": torch.tensor([[23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+    ],
+)
+def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "weight")
+    update_weight_global_scale(linear)
+    update_weight_zp_scale(linear)
+
+    observer = getattr(linear, "weight_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
+    input = torch.rand((1, input_size), dtype=torch.bfloat16)
+    output = linear(input)
+    true_output = input @ linear.weight.T
+    assert torch.allclose(output, true_output, atol=exp_tol)
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_tol",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
+            {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        # (
+        #     QuantizationArgs(
+        #         num_bits=4,
+        #         type="int",
+        #         symmetric=True,
+        #         strategy="block",
+        #         block_structure=[2, 3],
+        #         observer="minmax",
+        #     ),
+        #     {
+        #         "block_0_0": torch.tensor([[0]], dtype=torch.bfloat16),
+        #         "block_0_1": torch.tensor([[3]], dtype=torch.bfloat16),
+        #         "block_1_0": torch.tensor([[12]], dtype=torch.bfloat16),
+        #         "block_1_1": torch.tensor([[15]], dtype=torch.bfloat16),
+        #     },
+        #     {
+        #         "block_0_0": torch.tensor([[8]], dtype=torch.bfloat16),
+        #         "block_0_1": torch.tensor([[11]], dtype=torch.bfloat16),
+        #         "block_1_0": torch.tensor([[20]], dtype=torch.bfloat16),
+        #         "block_1_1": torch.tensor([[23]], dtype=torch.bfloat16),
+        #     },
+        #     2.5,
+        # ),
+    ],
+)
+def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
+    # set up activation (and identity weight)
+    input_size = 6
+    input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "input")
+    linear.register_forward_pre_hook(calibrate_input_hook)
+    linear(input)  # calibration forward pass populates the observer
+
+    observer = getattr(linear, "input_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
+    linear.quantization_status = QuantizationStatus.FROZEN
+    output = linear(input)
+    true_output = input  # weight is identity, so the true output equals the input
+    assert torch.allclose(output, true_output, atol=exp_tol)
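
The expected observer values in the weight test follow directly from the arange weight. A quick standalone check (not part of the commit) for the channel strategy, where each output channel is one row of the 4x6 weight:

    import torch

    weight = torch.arange(24, dtype=torch.bfloat16).reshape(4, 6)
    print(weight.min(dim=1, keepdim=True).values)  # [[0.], [6.], [12.], [18.]]
    print(weight.max(dim=1, keepdim=True).values)  # [[5.], [11.], [17.], [23.]]

The new tests can be run on their own with, for example: pytest tests/llmcompressor/modifiers/calibration/test_observers.py -k quantization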
