|
1 | 1 | import inspect
|
2 | 2 | from typing import Dict, List, Optional, Tuple, Union
|
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | import torch
|
5 | 6 | from compressed_tensors.quantization import (
|
@@ -183,25 +184,25 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
|
183 | 184 |
|
184 | 185 | model._group_size = next(iter(group_size_set))
|
185 | 186 |
|
186 |
| - in_num_bits_set = set( |
| 187 | + num_bits_set = set( |
187 | 188 | group.input_activations.num_bits
|
188 | 189 | for group in config.config_groups.values()
|
189 | 190 | if group.input_activations is not None
|
| 191 | + ).union( |
| 192 | + set( |
| 193 | + group.output_activations.num_bits |
| 194 | + for group in config.config_groups.values() |
| 195 | + if group.output_activations is not None |
| 196 | + ) |
190 | 197 | )
|
191 |
| - assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, ( |
192 |
| - "AWQ activations must be 16-bit precision, " |
193 |
| - f"input activations {in_num_bits_set} not allowed" |
194 |
| - ) |
195 |
| - |
196 |
| - out_num_bits_set = set( |
197 |
| - group.output_activations.num_bits |
198 |
| - for group in config.config_groups.values() |
199 |
| - if group.output_activations is not None |
200 |
| - ) |
201 |
| - assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, ( |
202 |
| - "AWQ activations must be 16-bit precision, " |
203 |
| - f"output activations {out_num_bits_set} not allowed" |
204 |
| - ) |
| 198 | + if not (len(num_bits_set) == 0 or num_bits_set == {16}): |
| 199 | + warnings.warn( |
| 200 | + "A strategy including activation quantization was detected. " |
| 201 | + "AWQ was originally intended for weight-only quantization. " |
| 202 | + "Lower-precision activations are an experimental feature, and " |
| 203 | + "overall performance may be poor. If it is, consider using " |
| 204 | + "`W4A16` or `W4A16_ASYM` quantization schemes instead." |
| 205 | + ) |
205 | 206 |
|
206 | 207 | return model
|
207 | 208 |
|
|
0 commit comments