-
Notifications
You must be signed in to change notification settings - Fork 464
Expand file tree
/
Copy pathtest_quantization.py
More file actions
166 lines (143 loc) · 5.69 KB
/
test_quantization.py
File metadata and controls
166 lines (143 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import shutil
import tempfile
import unittest
import pytest
import torch
from compressed_tensors.quantization.utils import is_module_quantized
from parameterized import parameterized_class
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from llmcompressor import oneshot
from llmcompressor.args import DatasetArguments
from llmcompressor.pytorch.utils import tensors_to_device
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from tests.testing_utils import parse_params, requires_gpu
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs"
@requires_gpu
@pytest.mark.integration
@parameterized_class(parse_params(CONFIGS_DIRECTORY))
class TestQuantizationMatches(unittest.TestCase):
    """Integration test verifying quantized models survive a save/reload
    round-trip and stay within a perplexity budget.

    Parameterized by config files in ``CONFIGS_DIRECTORY`` (via
    ``parameterized_class``), each supplying ``model_stub``, ``new_recipe``
    and optionally ``ppl_threshold``. One oneshot quantization run is shared
    across all tests in the class via ``setUpClass``.
    """

    # Populated per-config by @parameterized_class.
    new_recipe = None
    ppl_threshold = None
    model_stub = None
    dataset = "ultrachat-200k"
    output = "tiny_llama_out"
    max_seq_length = 512
    weight_dtype = torch.float16
    num_eval = 64

    @classmethod
    def setUpClass(cls):
        """Load the base model and quantize it once for the whole class."""
        cls.test_dir = tempfile.mkdtemp()
        cls.model = AutoModelForCausalLM.from_pretrained(
            cls.model_stub, torch_dtype=cls.weight_dtype, device_map="cuda:0"
        )
        model = cls._run_oneshot(
            cls.model,
            cls.new_recipe,
            cls.dataset,
            os.path.join(cls.test_dir, cls.output),
        )
        cls.session_model = model

    @classmethod
    def tearDownClass(cls):
        """Remove the temp dir and release GPU memory held by the models."""
        shutil.rmtree(cls.test_dir)
        # Drop both references so empty_cache can actually reclaim memory;
        # session_model may alias model, but deleting the attribute is safe
        # either way.
        del cls.model
        del cls.session_model
        torch.cuda.empty_cache()

    @staticmethod
    def _run_oneshot(model, recipe, dataset, output_dir):
        """Run oneshot quantization calibration and return the quantized model.

        :param model: loaded HF causal-LM to quantize in place
        :param recipe: quantization recipe (path or string) to apply
        :param dataset: calibration dataset name
        :param output_dir: directory the quantized checkpoint is saved to
        :return: the quantized model returned by ``oneshot``
        """
        num_calibration_samples = 64
        max_seq_length = 512
        pad_to_max_length = False
        model = oneshot(
            model=model,
            dataset=dataset,
            output_dir=output_dir,
            max_seq_length=max_seq_length,
            num_calibration_samples=num_calibration_samples,
            recipe=recipe,
            pad_to_max_length=pad_to_max_length,
            clear_sparse_session=False,
            splits={"calibration": "train_gen[:1%]"},
            save_compressed=False,
        )
        return model

    def _get_quant_info(self, model):
        """Collect quantization parameters from every quantized module.

        :param model: model whose named modules are scanned
        :return: tuple of (weight info, input-activation info) dicts keyed by
            module name; weight entries are (scale, zero_point, weight),
            input entries are (scale, zero_point). Dynamically-quantized
            input activations are skipped since they have no stored scale/zp.
        """
        quant_info_weights = {}
        quant_info_inputs = {}
        for name, module in model.named_modules():
            if is_module_quantized(module):
                if module.quantization_scheme.weights is not None:
                    quant_info_weights[name] = (
                        module.weight_scale,
                        module.weight_zero_point,
                        module.weight,
                    )
                if module.quantization_scheme.input_activations is not None:
                    is_dynamic = module.quantization_scheme.input_activations.dynamic
                    if not is_dynamic:
                        quant_info_inputs[name] = (
                            module.input_scale,
                            module.input_zero_point,
                        )
        return quant_info_weights, quant_info_inputs

    def test_quantization_reload(self):
        """Reload the saved checkpoint and verify quantization params match."""
        model_reloaded = AutoModelForCausalLM.from_pretrained(
            os.path.join(self.test_dir, self.output),
            torch_dtype="auto",
            device_map="cuda:0",
        )
        og_weights, og_inputs = self._get_quant_info(self.model)
        reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)
        for name, (o_scale, o_zp, o_weight) in og_weights.items():
            n_scale, n_zp, n_weight = reloaded_weights[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
            assert o_zp.dtype == n_zp.dtype
            assert torch.equal(o_zp, n_zp)
            # we don't expect an exact match here because o_weight still has the
            # original weight and n_weight has been fake_quantized
            assert n_weight.dtype == o_weight.dtype == self.weight_dtype
        for name, (o_scale, o_zp) in og_inputs.items():
            n_scale, n_zp = reloaded_inputs[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
            assert o_zp.dtype == n_zp.dtype
            assert torch.equal(o_zp, n_zp)

    def _get_dataloader(self, data_args, tokenizer):
        """Build a batch-size-1, randomly-sampled calibration dataloader."""
        dataset_manager = TextGenerationDataset.load_from_registry(
            data_args.dataset,
            data_args=data_args,
            split="train_gen[:5%]",
            processor=tokenizer,
        )
        calib_dataset = dataset_manager()
        data_loader = DataLoader(
            calib_dataset,
            batch_size=1,
            collate_fn=DefaultDataCollator(),
            sampler=torch.utils.data.RandomSampler(calib_dataset),
        )
        return data_loader

    @torch.no_grad()
    def test_perplexity(self):
        """Average per-sample perplexity over ``num_eval`` samples must not
        exceed ``ppl_threshold``. NaN losses are excluded from the average.
        """
        if self.ppl_threshold is None:
            pytest.skip("Skipping perplexity calculation.")
        tokenizer = AutoTokenizer.from_pretrained(self.model_stub)
        data_args = DatasetArguments(
            dataset="ultrachat-200k",
            max_seq_length=self.max_seq_length,
        )
        dataloader = self._get_dataloader(data_args, tokenizer)
        total_ppl = 0.0
        total_non_nan = 0
        for idx, sample in enumerate(dataloader):
            if idx >= self.num_eval:
                break
            output = self.model(**tensors_to_device(sample, "cuda:0"))
            if torch.isnan(output.loss):
                continue
            total_ppl += torch.exp(output.loss).item()
            total_non_nan += 1
        # Fail with a clear message rather than ZeroDivisionError when every
        # evaluated loss was NaN.
        assert total_non_nan > 0, "all evaluated losses were NaN"
        avg_ppl = total_ppl / total_non_nan
        assert avg_ppl <= self.ppl_threshold