
Commit 0b79d09

fix typos
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c63986a commit 0b79d09

File tree

4 files changed, +76 -103 lines changed


src/llmcompressor/observers/helpers.py

Lines changed: 6 additions & 11 deletions
```diff
@@ -65,24 +65,19 @@ def _flatten_weight(
 
     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
         if g_idx is not None:
-            value = value.index_select(dim=1, index=g_idx)
+            value = value.index_select(dim=1, index=torch.argsort(g_idx))
 
         # (1, num_rows, num_groups, group_size)
         return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)
 
     if args.strategy == QuantizationStrategy.BLOCK:
         # (1, num_block_rows, num_block_cols, block_width * block_height)
         block_height, block_width = args.block_structure
-        num_rows, num_cols = value.shape
-        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
-        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        rows, cols = value.shape
+        block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True)
+        block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True)
         return (
-            value.reshape(
-                num_block_rows,
-                block_height,
-                num_block_cols,
-                block_width,
-            )
+            value.reshape(block_rows, block_height, block_cols, block_width)
             .transpose(1, 2)
             .flatten(-2, -1)
             .unsqueeze(0)
@@ -99,7 +94,7 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.TOKEN:
         # (batch_size, seq_len, hidden_dim)
         # warning: token quantization uses `compute_dynamic_scales_and_zp`
-        return value.flatten(2, -1)
+        return value
 
     if args.strategy == QuantizationStrategy.CHANNEL:
         raise ValueError("Channel quantization cannot be applied to activations")
```
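
For context, a minimal standalone sketch (not part of the commit) of the two behaviors touched above: gathering columns with `torch.argsort(g_idx)` so that each group becomes contiguous, and the single-call block reshape. Names and shapes here are illustrative only, and `strategy_cdiv` is replaced by plain integer division under the assumption that the shape divides evenly.

```python
import torch

# g_idx regrouping: g_idx[j] is assumed to be the group index of column j.
# Selecting columns in argsort(g_idx) order makes each group contiguous.
group_size = 2
g_idx = torch.tensor([1, 0, 1, 0])
value = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
regrouped = value.index_select(dim=1, index=torch.argsort(g_idx))
assert torch.equal(g_idx[torch.argsort(g_idx)], torch.tensor([0, 0, 1, 1]))

# Block flattening: reshape/transpose/flatten so each trailing slice is one block.
rows, cols = 4, 6
block_height, block_width = 2, 3
weight = torch.arange(rows * cols).reshape(rows, cols).float()
block_rows, block_cols = rows // block_height, cols // block_width

flattened = (
    weight.reshape(block_rows, block_height, block_cols, block_width)
    .transpose(1, 2)
    .flatten(-2, -1)
    .unsqueeze(0)
)
assert flattened.shape == (1, block_rows, block_cols, block_height * block_width)
# The (0, 0) slice is the top-left 2x3 block of `weight`, flattened row-major.
assert torch.equal(flattened[0, 0, 0], weight[:block_height, :block_width].reshape(-1))
```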

src/llmcompressor/observers/min_max.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -44,7 +44,13 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.min_vals is not None and self.averaging_constant != 1.0:
             # FUTURE: consider scaling by num observations (first dim)
             # rather than reducing by first dim
-            min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant)
-            max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant)
+            min_vals = self._lerp(min_vals, self.min_vals, self.averaging_constant)
+            max_vals = self._lerp(max_vals, self.max_vals, self.averaging_constant)
 
         return min_vals, max_vals
+
+    def _lerp(
+        self, input: torch.Tensor, end: torch.Tensor, weight: float
+    ) -> torch.Tensor:
+        """torch lerp kernel is not implemented for all data types"""
+        return (input * weight) + (end * (1.0 - weight))
```
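
A quick sketch (not from the commit) checking that the manual interpolation used by `_lerp` matches `torch.lerp` where the latter is supported. Note the swapped argument order: `_lerp(new, old, c)` corresponds to `torch.lerp(old, new, c)`, i.e. `c * new + (1 - c) * old`.

```python
import torch

def manual_lerp(input: torch.Tensor, end: torch.Tensor, weight: float) -> torch.Tensor:
    # same arithmetic as the new `_lerp` helper: weight * input + (1 - weight) * end
    return (input * weight) + (end * (1.0 - weight))

old = torch.tensor([1.0, -2.0, 3.0])   # running min/max values
new = torch.tensor([2.0, 0.0, -1.0])   # values from the latest observation
c = 0.01                               # averaging constant

assert torch.allclose(manual_lerp(new, old, c), torch.lerp(old, new, c))
```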

tests/llmcompressor/observers/test_helpers.py

Lines changed: 46 additions & 83 deletions
```diff
@@ -12,98 +12,61 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
 import torch
 from compressed_tensors.quantization import (
-    QuantizationConfig,
-    QuantizationStatus,
-    apply_quantization_config,
+    QuantizationArgs,
+    QuantizationScheme,
+    initialize_module_for_quantization,
 )
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from llmcompressor.modifiers.quantization.calibration import (
-    calibrate_input_hook,
-    initialize_observer,
-)
-from llmcompressor.observers.helpers import get_observer_token_count
-
-
-def _prep_for_input_quant_calibration(module: torch.nn.Module):
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        return
-
-    module.register_forward_pre_hook(calibrate_input_hook)
-    module.quantization_status = QuantizationStatus.CALIBRATION
 
+from llmcompressor.observers.helpers import flatten_for_calibration
 
-def test_get_observer_token_count():
-    model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    model.eval()
-    config = QuantizationConfig(
-        format="fakequant",
-        quantization_status="calibration",
-        config_groups={
-            "group_1": {
-                "input_activations": {
-                    "num_bits": 8,
-                    "type": "int",
-                    "symmetric": False,
-                    "strategy": "tensor",
-                },
-                "targets": ["Linear"],
-            },
-        },
-    )
-    apply_quantization_config(model, config)
-    model.apply(lambda module: initialize_observer(module, base_name="input"))
-    model.apply(_prep_for_input_quant_calibration)
-
-    # start calibration
-    calib_list = [
-        "I am a string that",
-        "is used for calibration so",
-        "that your model is",
-        "quantized properly.",
-    ]
 
-    total_num_tokens_observed = 0
-    for calib_sample in calib_list:
-        calib_tensor = tokenizer(calib_sample, return_tensors="pt")
-        _ = model(**calib_tensor)
-        total_num_tokens_observed += len(calib_tensor.input_ids.flatten())
+def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
+    perm = torch.randperm(columns)
+    return torch.tensor([index // group_size for index in range(columns)])[perm]
 
-    counter = get_observer_token_count(model)
 
-    # filter out the None values
-    # (tokens, in the appropriate format, that were not observed by the model)
-    counter = {k: v for k, v in counter.items() if v is not None}
+@pytest.mark.parametrize(
+    "args",
+    [
+        QuantizationArgs(strategy="tensor"),
+        QuantizationArgs(strategy="tensor_group", group_size=4),
+    ],
+)
+def test_flatten_for_calibration_input(args):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # iterate over all the layers in the model where the token count in the proper
-    # format is has been observed
-    for i in range(model.config.num_hidden_layers):
-        # fetch the tokens observed by the router
-        tokens_observed_by_router = counter.pop(
-            f"model.layers.{i}.block_sparse_moe.gate"
-        )
-        assert tokens_observed_by_router == total_num_tokens_observed
+    input = torch.empty((3, 5, 8))
+    input_flattened = flatten_for_calibration(input, "input", scheme.input_activations)
+    assert input_flattened.shape[1:-1] == module.input_scale.shape
+    assert input_flattened.shape[1:-1] == module.input_zero_point.shape
 
-    # fetch the sum of tokens observed by all the experts
-    sum_tokens_observed_by_experts = 0
-    keys_for_this_layer = [
-        k
-        for k in counter.keys()
-        if f"model.layers.{i}.block_sparse_moe.experts" in k
-    ]
-    for key in keys_for_this_layer:
-        sum_tokens_observed_by_experts += counter.pop(key)
 
-    # each Mixtral expert is comprised of 3 linear layers,
-    # so we need to multiply by 3
-    assert (
-        sum_tokens_observed_by_experts
-        == total_num_tokens_observed * model.config.num_experts_per_tok * 3
-    )
+@pytest.mark.parametrize(
+    "args,g_idx",
+    [
+        (QuantizationArgs(strategy="tensor"), None),
+        (QuantizationArgs(strategy="channel"), None),
+        (QuantizationArgs(strategy="group", group_size=4), None),
+        (QuantizationArgs(strategy="group", group_size=4), make_dummy_g_idx(8, 4)),
+        (QuantizationArgs(strategy="tensor_group", group_size=4), None),
+        (QuantizationArgs(strategy="block", block_structure=[5, 4]), None),
+    ],
+)
+def test_flatten_for_calibration_weights(args, g_idx):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # there are no more information in the counter
-    assert len(counter) == 0
+    weight_flattened = flatten_for_calibration(
+        module.weight,
+        "weight",
+        scheme.weights,
+        g_idx=g_idx,
+    )
+    assert weight_flattened.shape[1:-1] == module.weight_scale.shape
+    assert weight_flattened.shape[1:-1] == module.weight_zero_point.shape
```
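
The new tests assert that the flattened tensor's middle dimensions line up with the scale and zero-point shapes, so one min/max reduction lands on each scale entry. A rough sketch of that invariant for the GROUP strategy, assuming compressed-tensors initializes group scales with shape `(out_features, num_groups)`; the reshape mirrors the GROUP branch of `_flatten_weight` shown above.

```python
import torch

out_features, in_features, group_size = 10, 8, 4
num_groups = in_features // group_size

weight = torch.randn(out_features, in_features)
# (1, out_features, num_groups, group_size), mirroring _flatten_weight for GROUP
flattened = weight.unflatten(-1, (-1, group_size)).unsqueeze(0)

expected_scale_shape = (out_features, num_groups)  # assumed weight_scale shape
assert flattened.shape[1:-1] == expected_scale_shape

# Reducing over the observation and group-size dims gives one value per scale entry.
group_min = flattened.amin(dim=(0, -1))
assert group_min.shape == expected_scale_shape
```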

tests/llmcompressor/observers/test_min_max.py

Lines changed: 16 additions & 7 deletions
```diff
@@ -82,15 +82,17 @@ def test_min_max_observer_value_update():
 
     tensor = inp
     num_bits = 8
-    weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="minmax")
+    weights = QuantizationArgs(
+        num_bits=num_bits, strategy="tensor", symmetric=True, observer="minmax"
+    )
     observer = weights.observer
     observer = Observer.load_from_registry(observer, base_name="weight", args=weights)
     curr_max = 1
     curr_min = 1
     for i, tensor in enumerate(tensors):
         observer(tensor)
-        curr_max = max(observer.max_val.get("default"), curr_max)
-        curr_min = min(observer.min_val.get("default"), curr_max)
+        curr_max = max(observer.max_vals[0], curr_max)
+        curr_min = min(observer.min_vals[0], curr_min)
 
         if i < 2:
             assert curr_max == 1
@@ -108,13 +110,20 @@ def test_g_idx():
     input_shape = (128, 512)
     tensor = torch.rand(input_shape)
     weights = QuantizationArgs(num_bits=8, group_size=group_size, observer="minmax")
+
+    module = torch.nn.Linear(512, 1)
     g_idx = make_dummy_g_idx(tensor.shape[1], group_size)
+    module.weight_g_idx = g_idx
 
-    observer = weights.observer
-    observer = Observer.load_from_registry(observer, base_name="weight", args=weights)
-    scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx)
+    observer = Observer.load_from_registry(
+        weights.observer, base_name="weight", args=weights, module=module
+    )
+    scale_g_idx, zero_point_g_idx = observer(tensor)
 
-    observer.reset()
+    observer = Observer.load_from_registry(
+        weights.observer, base_name="weight", args=weights, module=module
+    )
+    del module.weight_g_idx
     scale, zero_point = observer(tensor[:, torch.argsort(g_idx)])
 
     assert scale_g_idx == pytest.approx(scale)
```
