
Commit 930e35c

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9e916b6 commit 930e35c


2 files changed: +34 -37


src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 4 additions & 0 deletions
@@ -198,6 +198,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
         calculate_qparams = False
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
         calculate_gparam = True
+
+    # (..., 1, hidden_dim)
+    # this reshaping is mostly for the benefit of group quantization
+    value = value.unsqueeze(-2)
 
     call_observer(
         module=module,
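
For context on the new lines: a minimal sketch of what the reshape does to a calibration activation. The shapes below are illustrative, not taken from the repo.

import torch

# a batch of activations: (batch, seq_len, hidden_dim)
value = torch.randn(2, 8, 64)

# unsqueeze(-2) inserts a singleton "rows" dim: (batch, seq_len, 1, hidden_dim),
# letting group-quantization observers treat the last two dims as
# (rows, columns), matching the 2D weight path
value = value.unsqueeze(-2)
print(value.shape)  # torch.Size([2, 8, 1, 64])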

src/llmcompressor/observers/base.py

Lines changed: 30 additions & 37 deletions
@@ -10,7 +10,7 @@
 )
 from compressed_tensors.quantization.utils import is_fp4
 from compressed_tensors.registry.registry import RegistryMixin
-from compressed_tensors.utils import safe_permute
+from compressed_tensors.quantization.utils import strict_divide
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
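
The import swap above trades `safe_permute` for `strict_divide`. The real helper lives in `compressed_tensors.quantization.utils`; the stand-in below is only a sketch of the behavior this diff relies on (its actual signature may differ), replacing the old ceil-and-warn path with a hard failure on non-divisible columns.

def strict_divide(value: int, divisor: int) -> int:
    # hypothetical stand-in: integer division that fails loudly instead of
    # silently rounding up, as the removed ceil + warning code did
    if value % divisor != 0:
        raise ValueError(f"{value} is not divisible by {divisor}")
    return value // divisor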

@@ -125,8 +125,6 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            group_size = self.quantization_args.group_size
-
             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)
@@ -135,50 +133,43 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should be identical implementation to first half of
+                # `_process_quantization`
 
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strict_divide(columns, group_size)
+
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()
 
+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )
 
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
+                # permute groups
+                if g_idx is not None:
                     perm = torch.argsort(g_idx)
-                    observed = safe_permute(observed, perm, dim=1)
+                    observed = observed.index_select(-1, perm)
 
                 # TODO: experiment with vectorizing for loop for performance
+                # all reduce all dims except the last one
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=tuple(range(observed.ndim - 1)),
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )
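
A self-contained sketch of the rewritten group loop, with a plain min/max reduction standing in for `self.get_qparams_along_dim` (the toy /255 scale and shapes are illustrative only). Note that the `dim` argument of `get_qparams_along_dim` names the dims to keep, so `tuple(range(observed.ndim - 1))` keeps every leading dim and reduces only the last one, the group's columns.

import torch

observed = torch.randn(4, 64)                  # (rows, columns), the weight case
group_size = 16
num_groups = observed.shape[-1] // group_size  # strict_divide in the diff

scales = torch.empty(observed.shape[-2], num_groups)
end = 0
for group_index in range(num_groups):
    start, end = end, end + group_size
    group = observed[..., start:end]           # (..., rows, group_size)
    # keep every leading dim, reduce over the group's columns
    g_min = group.amin(dim=-1)
    g_max = group.amax(dim=-1)
    scales[..., group_index] = (g_max - g_min) / 255  # toy int8-style scale

print(scales.shape)  # torch.Size([4, 4])
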
@@ -187,21 +178,23 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # all reduce all dims except the last one
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim=tuple(range(observed.ndim - 1)),
+                )
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-                # use dim 1, assume the obsersed.shape = [batch, token, hidden]
-                # should be batch, token
+                # all reduce all dims except the last one
                 self._scale, self._zero_point = self.get_qparams_along_dim(
                     observed,
-                    dim={0, 1},
+                    dim=tuple(range(observed.ndim - 1)),
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                 # Block-wise quantization: one scale/zero_point per block of shape
                 # [block_rows, block_cols]
-                rows, cols = observed.shape[:2]
+                rows, cols = observed.shape[-2:]
                 bs = self.quantization_args.block_structure
                 if not (
                     isinstance(bs, (list, tuple))
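
The CHANNEL and TOKEN branches now share the same keep-all-leading-dims expression. A small sketch of why it generalizes both of the removed hard-coded dims; the helper here mirrors how `get_qparams_along_dim` derives its reduce dims, with `amax` as a stand-in for the real qparam calculation.

import torch

def keep_all_but_last(t: torch.Tensor) -> torch.Tensor:
    dim = tuple(range(t.ndim - 1))   # dims to keep
    reduce_dims = tuple(i for i in range(t.ndim) if i not in dim)
    return t.amax(dim=reduce_dims)   # stand-in for calculate_qparams

weight = torch.randn(128, 512)          # CHANNEL: one qparam per output channel
acts = torch.randn(2, 8, 512)           # TOKEN: one qparam per (batch, token)
print(keep_all_but_last(weight).shape)  # torch.Size([128])
print(keep_all_but_last(acts).shape)    # torch.Size([2, 8])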
