Skip to content

Commit 5c643b0

Browse files
authored
[NVFP4] Update to use tensor_group strategy; update observers (#1484)
SUMMARY: - Requires neuralmagic/compressed-tensors#325 - Uses the new `tensor_group` strategy for nvfp4a16 quantization - Removes `global_scale` as an observer constructor parameter and instead passes it in as an argument on each observer call, similar to `g_idx`
1 parent 94a3e53 commit 5c643b0

File tree

4 files changed

+57
-26
lines changed

4 files changed

+57
-26
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
QuantizationStatus,
88
)
99
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
10-
from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
10+
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
1111
from compressed_tensors.utils import align_module_device, update_parameter_data
1212
from loguru import logger
1313
from torch.nn import Module
@@ -54,14 +54,9 @@ def initialize_observer(
5454
quantization_args = getattr(quantization_scheme, arg_name, None)
5555
# dont need observers for dynamic
5656
if quantization_args is not None and not quantization_args.dynamic:
57-
global_scale = getattr(module, f"{base_name}_global_scale", None)
58-
if global_scale is not None:
59-
assert base_name == "weight" and is_fp4(quantization_args=quantization_args)
60-
6157
observer = Observer.load_from_registry(
6258
quantization_args.observer,
6359
quantization_args=quantization_args,
64-
global_scale=global_scale,
6560
)
6661
module.register_module(f"{base_name}_observer", observer)
6762

@@ -80,15 +75,19 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
8075
if base_name == "weight":
8176
value = module.weight
8277
g_idx = getattr(module, "weight_g_idx", None)
78+
global_scale = getattr(module, f"{base_name}_global_scale", None)
8379
elif value is not None:
8480
g_idx = None
81+
global_scale = None
8582
else:
8683
raise ValueError(
8784
"Must provide a value to observe if not using weight observer"
8885
)
8986

9087
observer = getattr(module, f"{base_name}_observer")
91-
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
88+
updated_scale, updated_zero_point = observer(
89+
value, g_idx=g_idx, global_scale=global_scale
90+
)
9291

9392
# update scale and zero point
9493
update_parameter_data(module, updated_scale, f"{base_name}_scale")

src/llmcompressor/observers/base.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,32 @@ class Observer(Module, RegistryMixin):
2727
def __init__(
2828
self,
2929
quantization_args: QuantizationArgs,
30-
global_scale: Optional[torch.Tensor] = None,
3130
):
3231
self.quantization_args: QuantizationArgs = quantization_args
3332
super().__init__()
34-
self.global_scale: Optional[torch.Tensor] = global_scale
3533
self._scale = None
3634
self._zero_point = None
3735
self._num_observed_tokens = None
3836

3937
@torch.no_grad()
4038
def forward(
41-
self, observed: Tensor, g_idx: Optional[Tensor] = None
39+
self,
40+
observed: Tensor,
41+
g_idx: Optional[Tensor] = None,
42+
global_scale: Optional[Tensor] = None,
4243
) -> Tuple[FloatTensor, IntTensor]:
4344
"""
4445
maps directly to get_qparams
4546
:param observed: optional observed tensor from which to calculate
4647
quantization parameters
4748
:param g_idx: optional mapping from column index to group index
49+
:param global_scale: optional scale to further scale local quantization scales
4850
:return: tuple of scale and zero point based on last observed value
4951
"""
5052
self.record_observed_tokens(observed)
51-
return self.get_qparams(observed=observed, g_idx=g_idx)
53+
return self.get_qparams(
54+
observed=observed, g_idx=g_idx, global_scale=global_scale
55+
)
5256

5357
def calculate_qparams(
5458
self,
@@ -73,6 +77,7 @@ def get_qparams(
7377
self,
7478
observed: Optional[Tensor] = None,
7579
g_idx: Optional[Tensor] = None,
80+
global_scale: Optional[Tensor] = None,
7681
) -> Tuple[FloatTensor, IntTensor]:
7782
"""
7883
Convenience function to wrap overwritten calculate_qparams
@@ -82,6 +87,7 @@ def get_qparams(
8287
:param observed: optional observed tensor to calculate quantization parameters
8388
from
8489
:param g_idx: optional mapping from column index to group index
90+
:param global_scale: optional scale to further scale local quantization scales
8591
:return: tuple of scale and zero point based on last observed value
8692
"""
8793
if observed is not None:
@@ -91,7 +97,10 @@ def get_qparams(
9197
# re-calculate scale and zero point, update the stored value
9298
self._scale, self._zero_point = self.calculate_qparams(observed)
9399

94-
elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
100+
elif self.quantization_args.strategy in (
101+
QuantizationStrategy.TENSOR_GROUP,
102+
QuantizationStrategy.GROUP,
103+
):
95104
rows = observed.shape[0]
96105
columns = observed.shape[1]
97106
num_groups = int(ceil(columns / group_size))
@@ -128,6 +137,7 @@ def get_qparams(
128137
observed[:, start:end],
129138
0,
130139
tensor_id=group_index,
140+
global_scale=global_scale
131141
)
132142

133143
self._scale[:, group_index] = scale.squeeze(1)
@@ -160,14 +170,18 @@ def get_qparams_along_dim(
160170
observed,
161171
dim: Union[int, Iterable[int]],
162172
tensor_id: Optional[Any] = None,
173+
global_scale: Optional[Tensor] = None,
163174
):
164175
if isinstance(dim, int):
165176
dim = [dim]
166177
dim = set(dim)
167178

168179
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
169180
return self.calculate_qparams(
170-
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
181+
observed,
182+
reduce_dims=reduce_dims,
183+
tensor_id=tensor_id,
184+
global_scale=global_scale,
171185
)
172186

173187
def record_observed_tokens(self, batch_tensor: Tensor):

src/llmcompressor/observers/min_max.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def __init__(
2222
self,
2323
quantization_args: QuantizationArgs,
2424
averaging_constant: float = 0.01,
25-
global_scale: Optional[torch.Tensor] = None,
2625
):
27-
super().__init__(quantization_args=quantization_args, global_scale=global_scale)
26+
super().__init__(quantization_args=quantization_args)
2827

2928
self.min_val = {}
3029
self.max_val = {}
@@ -35,6 +34,7 @@ def calculate_qparams(
3534
observed: torch.Tensor,
3635
reduce_dims: Optional[Tuple[int]] = None,
3736
tensor_id: Optional[Any] = None,
37+
global_scale: Optional[torch.Tensor] = None,
3838
) -> Tuple[torch.FloatTensor, torch.IntTensor]:
3939
"""
4040
Updates the observed min and max using a moving average smoothed by the
@@ -46,6 +46,7 @@ def calculate_qparams(
4646
reduced dimensions
4747
:param tensor_id: Optional id if different ranges of observed tensors are
4848
passed, useful for sharding tensors by group_size
49+
:param global_scale: optional scale to further scale local quantization scales
4950
:return: tuple of scale and zero point derived from the observed tensor
5051
"""
5152
tensor_id = tensor_id or "default"
@@ -62,7 +63,7 @@ def calculate_qparams(
6263
min_vals=min_val,
6364
max_vals=max_val,
6465
quantization_args=self.quantization_args,
65-
global_scale=self.global_scale,
66+
global_scale=global_scale,
6667
)
6768

6869
running_min_val = self.min_val.get(tensor_id, None)
@@ -86,18 +87,25 @@ def calculate_qparams(
8687
min_vals=updated_min_val,
8788
max_vals=updated_max_val,
8889
quantization_args=self.quantization_args,
89-
global_scale=self.global_scale,
90+
global_scale=global_scale,
9091
)
9192

9293
def get_qparams_along_dim(
93-
self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
94+
self,
95+
observed: torch.Tensor,
96+
dim: int,
97+
tensor_id: Optional[Any] = None,
98+
global_scale: Optional[torch.Tensor] = None,
9499
):
95100
"""
96101
Calculate quantization parameters along the specified dimension
97102
"""
98103
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
99104
return self.calculate_qparams(
100-
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
105+
observed,
106+
reduce_dims=reduce_dims,
107+
tensor_id=tensor_id,
108+
global_scale=global_scale,
101109
)
102110

103111
def reset(self):

src/llmcompressor/observers/mse.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ def __init__(
2323
averaging_constant: float = 0.01,
2424
grid: float = 100.0,
2525
norm: float = 2.4,
26-
global_scale: Optional[torch.Tensor] = None,
2726
):
28-
super().__init__(quantization_args=quantization_args, global_scale=global_scale)
27+
super().__init__(quantization_args=quantization_args)
2928

3029
kwargs = quantization_args.observer_kwargs or {}
3130
self.maxshrink = kwargs.get("maxshrink", 0.20)
@@ -41,6 +40,7 @@ def calculate_mse_min_max(
4140
self,
4241
observed: Tensor,
4342
reduce_dims: Optional[Tuple[int]] = None,
43+
global_scale: Optional[torch.Tensor] = None,
4444
):
4545
"""
4646
Computes the mse-clipped min and max values of the observed tensor by
@@ -49,6 +49,7 @@ def calculate_mse_min_max(
4949
:param observed: observed tensor to calculate quantization parameters for
5050
:param reduce_dims: optional tuple of dimensions to reduce along,
5151
returned values will be shaped (1,) along the reduced dimensions
52+
:param global_scale: optional scale to further scale local quantization scales
5253
:return: tuple of min and max values derived from the observed tensor
5354
"""
5455
from compressed_tensors.quantization.lifecycle import fake_quantize
@@ -77,14 +78,14 @@ def calculate_mse_min_max(
7778
min_vals=shrinked_min_val,
7879
max_vals=shrinked_max_val,
7980
quantization_args=self.quantization_args,
80-
global_scale=self.global_scale,
81+
global_scale=global_scale,
8182
)
8283
q = fake_quantize(
8384
observed,
8485
candidate_scales,
8586
candidate_zero_points,
8687
self.quantization_args,
87-
global_scale=self.global_scale,
88+
global_scale=global_scale,
8889
)
8990

9091
q -= observed
@@ -113,6 +114,7 @@ def calculate_qparams(
113114
observed: Tensor,
114115
reduce_dims: Optional[Tuple[int]] = None,
115116
tensor_id: Optional[Any] = None,
117+
global_scale: Optional[torch.Tensor] = None,
116118
) -> Tuple[FloatTensor, IntTensor]:
117119
"""
118120
Updates the mse-clipped min and max values of the observed tensor using
@@ -124,6 +126,7 @@ def calculate_qparams(
124126
reduced dimensions
125127
:param tensor_id: Optional id if different ranges of observed tensors are
126128
passed, useful for sharding tensors by group_size
129+
:param global_scale: optional scale to further scale local quantization scales
127130
:return: tuple of scale and zero point derived from the observed tensor
128131
"""
129132
min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
@@ -150,15 +153,22 @@ def calculate_qparams(
150153
min_vals=updated_min_val,
151154
max_vals=updated_max_val,
152155
quantization_args=self.quantization_args,
153-
global_scale=self.global_scale,
156+
global_scale=global_scale,
154157
)
155158

156159
def get_qparams_along_dim(
157-
self, observed, dim: int, tensor_id: Optional[Any] = None
160+
self,
161+
observed,
162+
dim: int,
163+
tensor_id: Optional[Any] = None,
164+
global_scale: Optional[torch.Tensor] = None,
158165
):
159166
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
160167
return self.calculate_qparams(
161-
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
168+
observed,
169+
reduce_dims=reduce_dims,
170+
tensor_id=tensor_id,
171+
global_scale=global_scale,
162172
)
163173

164174
def reset(self):

0 commit comments

Comments
 (0)