Commit ad5f719

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b4dfb19 commit ad5f719

2 files changed: 33 additions, 40 deletions


src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 2 additions & 2 deletions
@@ -200,8 +200,8 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
         calculate_gparam = True
 
     # (..., 1, hidden_dim)
-    # this reshaping is mostly for the benefit of group quantization
-    value = value.unsqueeze(-2)
+    # the second to last dim indicates that activations have one output channel
+    value = value.flatten(0, -1).unsqueeze(-2)
 
     call_observer(
         module=module,
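For context, here is a minimal sketch (not part of the commit) of what the old and new reshapes produce for a batched activation; the tensor and its shape are illustrative. Note that, as written, flatten(0, -1) collapses every dim, hidden_dim included, so the result is (1, batch * seq_len * hidden_dim) rather than the (..., 1, hidden_dim) shape named in the comment; flatten(0, -2) would preserve hidden_dim.

import torch

# illustrative activation: (batch, seq_len, hidden_dim)
value = torch.randn(2, 3, 8)

# old: only insert a singleton second-to-last dim
old = value.unsqueeze(-2)                   # shape (2, 3, 1, 8)

# new: flatten all dims first, then insert the singleton dim
new = value.flatten(0, -1).unsqueeze(-2)    # shape (1, 48)

# variant that keeps hidden_dim as the last dim
alt = value.flatten(0, -2).unsqueeze(-2)    # shape (6, 1, 8)

print(old.shape, new.shape, alt.shape)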

tests/llmcompressor/modifiers/calibration/test_observers.py

Lines changed: 31 additions & 38 deletions
@@ -157,7 +157,7 @@ def assert_alike(a, b):
         ),
     ],
 )
-def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
+def test_static_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
     # set up weight
     input_size, output_size = 6, 4
     linear = torch.nn.Linear(input_size, output_size, bias=False)
@@ -205,7 +205,7 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
             observer="minmax",
         ),
         {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
-        {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+        {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
         2.5,
     ),
     (
@@ -216,28 +216,23 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
             strategy="token",
             observer="minmax",
         ),
-        {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
-        {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
-        2.5,
-    ),
-    (
-        QuantizationArgs(
-            num_bits=4,
-            type="int",
-            symmetric=True,
-            strategy="channel",
-            observer="minmax",
-        ),
-        {
-            "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-            1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
-        },
-        {
-            "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-            1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
-        },
+        {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+        {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
         2.5,
     ),
+    # channel is not supported, but (tensor == token == channel)
+    # (
+    #     QuantizationArgs(
+    #         num_bits=4,
+    #         type="int",
+    #         symmetric=True,
+    #         strategy="channel",
+    #         observer="minmax",
+    #     ),
+    #     {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+    #     {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
+    #     2.5,
+    # ),
     (
         QuantizationArgs(
             num_bits=4,
@@ -248,12 +243,12 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
             observer="minmax",
         ),
         {
-            "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-            1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            "default": torch.tensor([[0]], dtype=torch.bfloat16),
+            1: torch.tensor([[3]], dtype=torch.bfloat16),
         },
         {
-            "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-            1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            "default": torch.tensor([[2]], dtype=torch.bfloat16),
+            1: torch.tensor([[5]], dtype=torch.bfloat16),
         },
         2.5,
     ),
@@ -267,12 +262,12 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
             observer="minmax",
        ),
         {
-            "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-            1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            "default": torch.tensor([[0]], dtype=torch.bfloat16),
+            1: torch.tensor([[3]], dtype=torch.bfloat16),
         },
         {
-            "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-            1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            "default": torch.tensor([[2]], dtype=torch.bfloat16),
+            1: torch.tensor([[5]], dtype=torch.bfloat16),
         },
         2.5,
     ),
@@ -301,7 +296,7 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
     # ),
     ],
 )
-def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
+def test_static_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     # set up activation (and identity weight)
     input_size = 6
     input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
@@ -317,7 +312,11 @@ def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     initialize_observer(linear, "input")
     linear.register_forward_pre_hook(calibrate_input_hook)
 
+    # calibration forward pass
+    output = linear(input)
+    assert torch.allclose(output, input, atol=exp_tol)
 
+    # check calibration
     observer = getattr(linear, "input_observer")
     breakpoint()
     assert (
@@ -328,10 +327,4 @@ def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     )
     for key in observer.min_val.keys():
         assert torch.equal(observer.min_val[key], exp_min_val[key])
-        assert torch.equal(observer.max_val[key], exp_max_val[key])
-
-    # forward pass
-    linear.quantization_status = QuantizationStatus.FROZEN
-    output = linear(input)
-    true_output = input  # (@ linear.weight.T == eye)
-    assert torch.allclose(output, true_output, atol=exp_tol)
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
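As a sanity check on the new expected values, a small sketch (not from the repo) reproducing the observer extrema for the arange(6) input used by the test. The group size of 3 is an assumption inferred from the two expected groups; it is not shown in this diff.

import torch

value = torch.arange(6, dtype=torch.bfloat16).unsqueeze(0)  # shape (1, 6)
flat = value.flatten(0, -1).unsqueeze(-2)                   # shape (1, 6)

# tensor/token strategies observe a single min/max over all elements,
# which is why token now collapses to the same scalars as tensor
print(flat.min(), flat.max())  # 0.0 and 5.0

# group strategy: two groups of 3 (assumed group_size) over the 6 features
groups = flat.reshape(-1, 2, 3)
print(groups.amin(dim=-1))  # [[0., 3.]] -> min_val {"default": [[0]], 1: [[3]]}
print(groups.amax(dim=-1))  # [[2., 5.]] -> max_val {"default": [[2]], 1: [[5]]}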
