8 changes: 5 additions & 3 deletions src/llmcompressor/observers/base.py
@@ -10,7 +10,6 @@
)
from compressed_tensors.quantization.utils import is_fp4
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils import safe_permute
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor

@@ -51,8 +50,12 @@ def forward(
:return: tuple of scale and zero point based on last observed value
"""
self.record_observed_tokens(observed)

if should_calculate_gparam:
# NOTE: this function updates running min/max values, which leads to
# running values updating twice
return self.get_gparam(observed=observed)

return self.get_qparams(
observed=observed,
g_idx=g_idx,
@@ -168,8 +171,7 @@ def get_qparams(
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]

perm = torch.argsort(g_idx)
observed = safe_permute(observed, perm, dim=1)
observed = observed.index_select(-1, g_idx)

# TODO: experiment with vectorizing for loop for performance
end = 0
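For context on the `safe_permute` removal above, a minimal sketch (illustrative only, not part of the PR) of how `index_select` along the last dim gathers columns by index, i.e. `result[..., i] == observed[..., idx[i]]`:

```python
import torch

# Illustrative only: index_select along the last dim reorders columns
# according to the index tensor, equivalent to observed[:, idx] for 2D.
observed = torch.arange(12.0).reshape(3, 4)  # 3 rows, 4 columns
idx = torch.tensor([2, 0, 3, 1])             # desired column order
regrouped = observed.index_select(-1, idx)
assert torch.equal(regrouped, observed[:, idx])
```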
20 changes: 16 additions & 4 deletions src/llmcompressor/observers/min_max.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, Iterable, Optional, Tuple, Union

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -58,6 +58,8 @@ def calculate_updated_min_max(

# early stopping, save some computation and memory
if self.averaging_constant == 1.0:
self.min_val[tensor_id] = min_val
self.max_val[tensor_id] = max_val
return min_val, max_val

running_min_val = self.min_val.get(tensor_id, None)
@@ -85,7 +87,8 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
:param observed: observed tensor to calculate quantization parameters for
:return: updated global scale derived from the observed tensor
"""

# NOTE: this function updates running min/max values, which leads to
# running values updating twice
updated_min_val, updated_max_val = self.calculate_updated_min_max(
observed=observed
)
@@ -126,14 +129,23 @@ def calculate_qparams(
def get_qparams_along_dim(
self,
observed: torch.Tensor,
dim: int,
dim: Union[int, Iterable[int]],
tensor_id: Optional[Any] = None,
global_scale: Optional[torch.Tensor] = None,
):
"""
Calculate quantization parameters along the specified dimension
"""
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
# cast to set
if isinstance(dim, int):
dim = [dim]
dim = set(dim)

# convert negative dims
dim = [d if d >= 0 else observed.ndim + d for d in dim]
Collaborator: Shouldn't the cast to set happen after this line?

Collaborator (Author): Technically either is fine, since the argument type just needs to be an iterable. I'm purely matching the implementation on the base class for now.

Update get_qparams_along_dim to support multiple dims and negative dims. The old single-dim handling results in a silent typing bug with token quantization, which is already fixed in the base class implementation. This change essentially duplicates the base class implementation; future work could involve cleaning up the inheritance structure here.
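A minimal sketch of the silent bug being described (a hypothetical reconstruction, not code from the PR): when `dim` arrives as a set or a negative int, the old equality check never matches any axis index, so every dimension gets reduced.

```python
import torch

# Hypothetical reconstruction: comparing each positive axis index
# against a set (or a negative dim) is never equal, so ALL dims get
# reduced and per-token parameters silently collapse to per-tensor.
observed = torch.randn(4, 6)
for dim in ({0, 1}, -1):
    reduce_dims = tuple(i for i in range(observed.ndim) if i != dim)
    assert reduce_dims == (0, 1)  # should have excluded the target dim(s)
```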

Collaborator: I mean more that you might end up with duplicates in dim if you create this list and don't cast back to a set. E.g., if there are 3 dims and dim={1, 2, -1}, then dim=[1, 2, 2] after this line.

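A minimal sketch of the ordering the reviewer suggests (a hypothetical helper, not the PR's code): convert negative dims first, then deduplicate by casting back to a set so aliases like 2 and -1 collapse into one entry.

```python
def normalize_dims(dim, ndim):
    # Hypothetical helper: resolve negative dims first, then cast to a
    # set so equivalent indices (e.g. 2 and -1 when ndim == 3) dedupe.
    if isinstance(dim, int):
        dim = [dim]
    return {d if d >= 0 else ndim + d for d in dim}

# with 3 dims, {1, 2, -1} becomes {1, 2} rather than [1, 2, 2]
assert normalize_dims({1, 2, -1}, ndim=3) == {1, 2}
```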

# reduce all dimensions except the ones passed as argument to this function
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
return self.calculate_qparams(
observed,
reduce_dims=reduce_dims,
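For context on the early-stopping addition above, a minimal sketch of the moving-average min/max update (names are illustrative, not the library's exact code): with `averaging_constant == 1.0` the running value equals the current value, so the interpolation can be skipped, but the stored value must still be recorded, which is what the added lines do.

```python
import torch

def update_running(running, current, averaging_constant):
    # With no history, or averaging_constant == 1.0, the running value
    # is exactly the current value; otherwise interpolate toward it.
    if running is None or averaging_constant == 1.0:
        return current
    return running + averaging_constant * (current - running)

r_min = update_running(None, torch.tensor(-1.0), 0.5)   # -> -1.0
r_min = update_running(r_min, torch.tensor(-3.0), 0.5)  # -> -2.0
```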
281 changes: 280 additions & 1 deletion tests/llmcompressor/modifiers/calibration/test_observers.py
@@ -6,7 +6,12 @@
initialize_module_for_quantization,
)

from llmcompressor.modifiers.quantization.calibration import initialize_observer
from llmcompressor.modifiers.quantization.calibration import (
calibrate_input_hook,
initialize_observer,
update_weight_global_scale,
update_weight_zp_scale,
)


@pytest.mark.parametrize(
@@ -59,3 +64,277 @@ def test_observers_update(shape, group_size, actorder):
def assert_alike(a, b):
assert a.dtype == b.dtype
assert a.shape == b.shape


@pytest.mark.parametrize(
"args,exp_min_val,exp_max_val,exp_quant,exp_loss",
[
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="tensor", # equivalent to token
observer="minmax",
),
{"default": torch.tensor(0.0)},
{"default": torch.tensor(23.0)},
torch.tensor(
[
[0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
[6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
[12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
[18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
],
dtype=torch.bfloat16,
),
0.85,
),
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="channel",
observer="minmax",
),
{"default": torch.tensor([[0], [6], [12], [18]])},
{"default": torch.tensor([[5], [11], [17], [23]])},
torch.tensor(
[
[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875],
[5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
[11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750],
[18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
],
dtype=torch.bfloat16,
),
0.45,
),
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="group",
group_size=3,
observer="minmax",
),
{
"default": torch.tensor([[0], [6], [12], [18]]),
1: torch.tensor([[3], [9], [15], [21]]),
},
{
"default": torch.tensor([[2], [8], [14], [20]]),
1: torch.tensor([[5], [11], [17], [23]]),
},
torch.tensor(
[
[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875],
[6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
[11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750],
[18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
],
),
0.45,
),
(
QuantizationArgs(
num_bits=4,
type="float", # tensor group requires FP4
symmetric=True,
strategy="tensor_group", # requires float4
group_size=3,
observer="minmax",
),
{
"default": torch.tensor([[0], [6], [12], [18]]),
1: torch.tensor([[3], [9], [15], [21]]),
},
{
"default": torch.tensor([[2], [8], [14], [20]]),
1: torch.tensor([[5], [11], [17], [23]]),
},
torch.tensor(
[
[0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375],
[5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875],
[9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750],
[19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000],
],
),
1.1,
),
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="block",
block_structure=[2, 3],
observer="minmax",
),
{
"block_0_0": torch.tensor([[0]]),
"block_0_1": torch.tensor([[3]]),
"block_1_0": torch.tensor([[12]]),
"block_1_1": torch.tensor([[15]]),
},
{
"block_0_0": torch.tensor([[8]]),
"block_0_1": torch.tensor([[11]]),
"block_1_0": torch.tensor([[20]]),
"block_1_1": torch.tensor([[23]]),
},
torch.tensor(
[
[0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062],
[6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
[10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750],
[18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
],
),
0.5,
),
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="token", # equivalent to tensor
observer="minmax",
),
{"default": torch.tensor(0.0)},
{"default": torch.tensor(23.0)},
torch.tensor(
[
[0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
[6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
[12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
[18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
],
dtype=torch.bfloat16,
),
0.85,
),
],
)
def test_static_weight_quantization(
args, exp_min_val, exp_max_val, exp_quant, exp_loss
):
"""
weight = tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
"""
# set up weight
input_size, output_size = 6, 4
linear = torch.nn.Linear(input_size, output_size, bias=False)
linear.weight.data = torch.arange(
input_size * output_size, dtype=torch.bfloat16
).reshape(output_size, input_size)

# initialize quantization parameters
scheme = QuantizationScheme(targets=[], weights=args)
initialize_module_for_quantization(linear, scheme)
assert getattr(linear, "quantization_scheme") is scheme

# calibrate quantization parameters
initialize_observer(linear, "weight")
update_weight_global_scale(linear)
update_weight_zp_scale(linear)

observer = getattr(linear, "weight_observer")
assert (
observer.min_val.keys()
== observer.max_val.keys()
== exp_min_val.keys()
== exp_max_val.keys()
)
for key in observer.min_val.keys():
assert torch.equal(observer.min_val[key], exp_min_val[key])
assert torch.equal(observer.max_val[key], exp_max_val[key])

# forward pass
input = torch.eye(input_size, dtype=torch.bfloat16)
output = linear(input)

assert torch.allclose(output.T, exp_quant.to(output.dtype))
assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss


@pytest.mark.parametrize(
"args,exp_min_val,exp_max_val,exp_quant,exp_loss",
[
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="tensor", # equivalent to token
observer="minmax",
),
{"default": torch.tensor(0.0)},
{"default": torch.tensor(5.0)},
torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
0.06,
),
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="token", # equivalent to tensor
observer="minmax",
),
{"default": torch.tensor(0.0)},
{"default": torch.tensor(5.0)},
torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
0.06,
),
# channel is not supported, but is in principle equivalent to token/tensor
# group is not yet supported
# tensor_group is not yet supported
# block is not supported, but is in principle similar to group
],
)
def test_static_activation_quantization(
args, exp_min_val, exp_max_val, exp_quant, exp_loss
):
"""
input = tensor([[ 0, 1, 2, 3, 4, 5]])
"""
# set up activation (and identity weight)
input_size = 6
input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
linear = torch.nn.Linear(input_size, input_size, bias=False)
linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)

# initialize quantization parameters
scheme = QuantizationScheme(targets=[], input_activations=args)
initialize_module_for_quantization(linear, scheme)
assert getattr(linear, "quantization_scheme") is scheme

# calibrate quantization parameters
initialize_observer(linear, "input")
linear.register_forward_pre_hook(calibrate_input_hook)

# calibration forward pass
output = linear(input)

# check calibration
observer = getattr(linear, "input_observer")
assert (
observer.min_val.keys()
== observer.max_val.keys()
== exp_min_val.keys()
== exp_max_val.keys()
)
for key in observer.min_val.keys():
assert torch.equal(observer.min_val[key], exp_min_val[key])
assert torch.equal(observer.max_val[key], exp_max_val[key])

# check forward pass
assert torch.allclose(output, exp_quant.to(output.dtype))
assert torch.nn.functional.mse_loss(output, input) <= exp_loss
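For reference, a minimal sketch (assuming the standard symmetric rule scale = max_abs / ((qmax - qmin) / 2), not copied from compressed-tensors) of how the per-tensor expected values in these tests arise:

```python
import torch

# int4 symmetric: qmin, qmax = -8, 7, so the half-range is 7.5
qmin, qmax = -8, 7
max_abs = torch.tensor(23.0, dtype=torch.bfloat16)  # observed max
scale = max_abs / ((qmax - qmin) / 2)               # ~3.0625 in bfloat16

x = torch.arange(24, dtype=torch.bfloat16)
q = torch.clamp(torch.round(x / scale), qmin, qmax) * scale
# q steps through multiples of ~3.0625 (0, 0, 3.0625, 3.0625, 3.0625,
# 6.125, ...), matching the "tensor"/"token" expected tensors above
```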