 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
 import torch
 from compressed_tensors.quantization import (
-    QuantizationConfig,
-    QuantizationStatus,
-    apply_quantization_config,
+    QuantizationArgs,
+    QuantizationScheme,
+    initialize_module_for_quantization,
 )
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from llmcompressor.modifiers.quantization.calibration import (
-    calibrate_input_hook,
-    initialize_observer,
-)
-from llmcompressor.observers.helpers import get_observer_token_count
-
-
-def _prep_for_input_quant_calibration(module: torch.nn.Module):
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        return
-
-    module.register_forward_pre_hook(calibrate_input_hook)
-    module.quantization_status = QuantizationStatus.CALIBRATION
 
+from llmcompressor.observers.helpers import flatten_for_calibration
 
-def test_get_observer_token_count():
-    model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
-    model.eval()
-    config = QuantizationConfig(
-        format="fakequant",
-        quantization_status="calibration",
-        config_groups={
-            "group_1": {
-                "input_activations": {
-                    "num_bits": 8,
-                    "type": "int",
-                    "symmetric": False,
-                    "strategy": "tensor",
-                },
-                "targets": ["Linear"],
-            },
-        },
-    )
-    apply_quantization_config(model, config)
-    model.apply(lambda module: initialize_observer(module, base_name="input"))
-    model.apply(_prep_for_input_quant_calibration)
-
-    # start calibration
-    calib_list = [
-        "I am a string that",
-        "is used for calibration so",
-        "that your model is",
-        "quantized properly.",
-    ]
 
-    total_num_tokens_observed = 0
-    for calib_sample in calib_list:
-        calib_tensor = tokenizer(calib_sample, return_tensors="pt")
-        _ = model(**calib_tensor)
-        total_num_tokens_observed += len(calib_tensor.input_ids.flatten())
+def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
+    perm = torch.randperm(columns)
+    return torch.tensor([index // group_size for index in range(columns)])[perm]
 
-    counter = get_observer_token_count(model)
 
-    # filter out the None values
-    # (tokens, in the appropriate format, that were not observed by the model)
-    counter = {k: v for k, v in counter.items() if v is not None}
+@pytest.mark.parametrize(
+    "args",
+    [
+        QuantizationArgs(strategy="tensor"),
+        QuantizationArgs(strategy="tensor_group", group_size=4),
+    ],
+)
+def test_flatten_for_calibration_input(args):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # iterate over all the layers in the model where the token count in the proper
-    # format is has been observed
-    for i in range(model.config.num_hidden_layers):
-        # fetch the tokens observed by the router
-        tokens_observed_by_router = counter.pop(
-            f"model.layers.{i}.block_sparse_moe.gate"
-        )
-        assert tokens_observed_by_router == total_num_tokens_observed
+    input = torch.empty((3, 5, 8))
+    input_flattened = flatten_for_calibration(input, "input", scheme.input_activations)
+    assert input_flattened.shape[1:-1] == module.input_scale.shape
+    assert input_flattened.shape[1:-1] == module.input_zero_point.shape
 
-        # fetch the sum of tokens observed by all the experts
-        sum_tokens_observed_by_experts = 0
-        keys_for_this_layer = [
-            k
-            for k in counter.keys()
-            if f"model.layers.{i}.block_sparse_moe.experts" in k
-        ]
-        for key in keys_for_this_layer:
-            sum_tokens_observed_by_experts += counter.pop(key)
 
-        # each Mixtral expert is comprised of 3 linear layers,
-        # so we need to multiply by 3
-        assert (
-            sum_tokens_observed_by_experts
-            == total_num_tokens_observed * model.config.num_experts_per_tok * 3
-        )
+@pytest.mark.parametrize(
+    "args,g_idx",
+    [
+        (QuantizationArgs(strategy="tensor"), None),
+        (QuantizationArgs(strategy="channel"), None),
+        (QuantizationArgs(strategy="group", group_size=4), None),
+        (QuantizationArgs(strategy="group", group_size=4), make_dummy_g_idx(8, 4)),
+        (QuantizationArgs(strategy="tensor_group", group_size=4), None),
+        (QuantizationArgs(strategy="block", block_structure=[5, 4]), None),
+    ],
+)
+def test_flatten_for_calibration_weights(args, g_idx):
+    module = torch.nn.Linear(8, 10)
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(module, scheme)
 
-    # there are no more information in the counter
-    assert len(counter) == 0
+    weight_flattened = flatten_for_calibration(
+        module.weight,
+        "weight",
+        scheme.weights,
+        g_idx=g_idx,
+    )
+    assert weight_flattened.shape[1:-1] == module.weight_scale.shape
+    assert weight_flattened.shape[1:-1] == module.weight_zero_point.shape
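The asserts in both new tests pin down a shape contract: the middle dimensions of the flattened tensor must match the scale and zero-point shapes. A minimal sketch of that contract for the "group" strategy follows, assuming flatten_for_calibration returns a view shaped (num_observations, *qparam_shape, group_size); the helper name and reshape here are illustrative, not the llmcompressor implementation.

import torch

def flatten_grouped_weight(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # Illustrative assumption: strategy="group" views a (rows, columns)
    # weight as (1, rows, columns // group_size, group_size), so that
    # shape[1:-1] equals the (rows, columns // group_size) scale shape.
    num_rows, num_columns = weight.shape
    assert num_columns % group_size == 0
    return weight.reshape(1, num_rows, num_columns // group_size, group_size)

# For a Linear(8, 10) weight of shape (10, 8) with group_size=4, as in the
# test above, the middle dimensions are (10, 2), matching weight_scale.
flattened = flatten_grouped_weight(torch.randn(10, 8), group_size=4)
assert flattened.shape[1:-1] == torch.Size([10, 2])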
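make_dummy_g_idx fabricates a permuted group index of the kind produced by GPTQ-style activation ordering, where g_idx[i] names the quantization group of column i. A hedged sketch of how such an index could be consumed before grouping; this is an assumed behavior for illustration, not the library's code path.

import torch

def regroup_by_g_idx(
    weight: torch.Tensor, g_idx: torch.Tensor, group_size: int
) -> torch.Tensor:
    # Assumption for illustration: sorting g_idx gathers the columns of each
    # group into contiguous runs, after which ordinary grouping applies.
    perm = torch.argsort(g_idx)
    return weight[:, perm].reshape(weight.shape[0], -1, group_size)

weight = torch.randn(10, 8)
g_idx = torch.tensor([i // 4 for i in range(8)])[torch.randperm(8)]  # as in make_dummy_g_idx(8, 4)
assert regroup_by_g_idx(weight, g_idx, group_size=4).shape == (10, 2, 4)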