
Commit c8ae1ed

squash

Signed-off-by: Kyle Sayers <[email protected]>

1 parent: c098447

File tree

4 files changed (+77, -39 lines)

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 4 additions & 1 deletion

@@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()


 def update_weight_zp_scale(module: Module):

@@ -199,6 +198,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
         calculate_gparam = True

+    # (..., 1, hidden_dim)
+    # the second to last dim indicates that activations have one output channel
+    value = value.flatten(0, -1).unsqueeze(-2)
+
     call_observer(
         module=module,
         base_name=base_name,
src/llmcompressor/observers/base.py

Lines changed: 32 additions & 35 deletions

@@ -8,7 +8,7 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import is_fp4
+from compressed_tensors.quantization.utils import is_fp4, strict_divide
 from compressed_tensors.registry.registry import RegistryMixin
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor

@@ -128,8 +128,6 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            group_size = self.quantization_args.group_size
-
             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)

@@ -138,49 +136,43 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should be identical implementation to first half of
+                # `_process_quantization`

-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strict_divide(columns, group_size)
+
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()

+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )

-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                observed = observed.index_select(-1, g_idx)
+                # permute groups
+                if g_idx is not None:
+                    perm = torch.argsort(g_idx)
+                    observed = observed.index_select(-1, perm)

                 # TODO: experiment with vectorizing for loop for performance
+                # all reduce all dims except the second to last one
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=-2,
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )

@@ -189,8 +181,8 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)

             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # all reduce all dims except the second to last one
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)

             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]

@@ -203,7 +195,7 @@ def get_qparams(
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                 # Block-wise quantization: one scale/zero_point per block of shape
                 # [block_rows, block_cols]
-                rows, cols = observed.shape[:2]
+                rows, cols = observed.shape[-2:]
                 bs = self.quantization_args.block_structure
                 if not (
                     isinstance(bs, (list, tuple))

@@ -255,15 +247,20 @@ def get_qparams(

     def get_qparams_along_dim(
         self,
-        observed,
+        observed: torch.Tensor,
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[Tensor] = None,
     ):
+        # cast to set
         if isinstance(dim, int):
             dim = [dim]
         dim = set(dim)

+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce all dimensions except the the one passed as argument to this function
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
-from compressed_tensors.utils import deprecated
+from compressed_tensors.utils import deprecated, patch_attr

 from llmcompressor.observers.base import Observer
tests/llmcompressor/modifiers/calibration/test_observers.py

Lines changed: 40 additions & 2 deletions

@@ -294,8 +294,46 @@ def test_static_weight_quantization(
             0.06,
         ),
         # channel is not supported, but is in principle equivalent to token/tensor
-        # group is not yet supported
-        # tensor_group is not yet supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0]]),
+                1: torch.tensor([[3]]),
+            },
+            {
+                "default": torch.tensor([[2]]),
+                1: torch.tensor([[5]]),
+            },
+            torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]),
+            0.04,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0]]),
+                1: torch.tensor([[3]]),
+            },
+            {
+                "default": torch.tensor([[2]]),
+                1: torch.tensor([[5]]),
+            },
+            torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]),
+            0.1,
+        ),
         # block is not supported, but is in principle similar to group
     ],
 )
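Reading the new group test case: with group_size=3, a (1, 6) weight splits into two groups, which is why the expected min/max dicts are keyed by "default" (group 0) and 1 (group 1). The sketch below uses a hypothetical ascending weight, consistent with those expected values but not taken from the test fixture, which is defined earlier in the file.

import torch

weight = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]])  # hypothetical stand-in, not the test fixture
group_size = 3
for group_index in range(weight.shape[-1] // group_size):
    group = weight[..., group_index * group_size : (group_index + 1) * group_size]
    print(group_index, group.amin(dim=-1), group.amax(dim=-1))
# 0 tensor([0.]) tensor([2.])
# 1 tensor([3.]) tensor([5.])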
