
Commit fd71824

Revert "expand observers to calculate gparams, add example for activations" (#1486)

This reverts commit 830c904.

SUMMARY: "please provide a brief summary"
TEST PLAN: "please outline how the changes were tested"

1 parent 830c904 · commit fd71824

8 files changed: +17 −201 lines changed

examples/quantization_w4a4_fp4/llama3_example.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 5 additions & 40 deletions
@@ -5,7 +5,6 @@
     KVCacheScaleType,
     QuantizationScheme,
     QuantizationStatus,
-    QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -85,48 +84,14 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
                 "Must provide a value to observe if not using weight observer"
             )
 
-        quantization_scheme = getattr(module, "quantization_scheme", None)
-        should_calculate_gparam = False
-        should_calculate_qparams = True
-
-        # TODO: will update to be the case for both weight and input in a follow-up
-        # weight global calculate is currently done in ct right now; s
-        # should be moved here to unify global scale calculations
-        if (
-            quantization_scheme.strategy == QuantizationStrategy.TENSOR_GROUP
-            and base_name == "input"
-        ):
-            should_calculate_gparam = True
-            should_calculate_qparams = False
-
         observer = getattr(module, f"{base_name}_observer")
-        observer_outputs = observer(
-            value,
-            g_idx=g_idx,
-            global_scale=global_scale,
-            should_calculate_gparam=should_calculate_gparam,
-            should_calculate_qparams=should_calculate_qparams,
+        updated_scale, updated_zero_point = observer(
+            value, g_idx=g_idx, global_scale=global_scale
         )
 
-        if should_calculate_qparams:
-            if should_calculate_gparam:
-                updated_scale, updated_zero_point, updated_global_scale = (
-                    observer_outputs
-                )
-            else:
-                updated_scale, updated_zero_point = observer_outputs
-        else:
-            updated_global_scale = observer_outputs
-
-        if should_calculate_gparam:
-            update_parameter_data(
-                module, updated_global_scale, f"{base_name}_global_scale"
-            )
-
-        if should_calculate_qparams:
-            # update scale and zero point
-            update_parameter_data(module, updated_scale, f"{base_name}_scale")
-            update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+        # update scale and zero point
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
 
 def update_weight_zp_scale(module: Module):
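After the revert, call_observer expects exactly one return shape from the observer: a (scale, zero_point) pair that is written back onto the module as f"{base_name}_scale" and f"{base_name}_zero_point". A minimal standalone sketch of that flow, using plain torch and a toy min-max observer in place of the real Observer and update_parameter_data APIs:

    import torch

    def toy_observer(value: torch.Tensor):
        # Stand-in for an llmcompressor observer: symmetric int8-style scale over
        # the whole tensor, plus an all-zero zero point.
        scale = value.abs().amax() / 127.0
        zero_point = torch.zeros_like(scale, dtype=torch.int64)
        return scale, zero_point

    # Post-revert calibration always unpacks exactly two values and stores them.
    value = torch.randn(4, 16)
    updated_scale, updated_zero_point = toy_observer(value)
    params = {"input_scale": updated_scale, "input_zero_point": updated_zero_point}
    print(params)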

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 2 additions & 9 deletions
@@ -2,7 +2,6 @@
 
 import torch
 from compressed_tensors.quantization import (
-    DynamicType,
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
@@ -213,10 +212,7 @@ def _initialize_observers(self, module: torch.nn.Module):
             return
 
         scheme: QuantizationScheme = module.quantization_scheme
-        input = scheme.input_activations and scheme.input_activations.dynamic in (
-            False,
-            DynamicType.LOCAL,
-        )
+        input = scheme.input_activations and not scheme.input_activations.dynamic
         weight = scheme.weights is not None
         output = scheme.output_activations and not scheme.output_activations.dynamic
         is_attention = is_attention_module(module)
@@ -245,10 +241,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                 continue
 
             scheme: QuantizationScheme = module.quantization_scheme
-            input = scheme.input_activations and scheme.input_activations.dynamic in (
-                False,
-                DynamicType.LOCAL,
-            )
+            input = scheme.input_activations and not scheme.input_activations.dynamic
             output = scheme.output_activations and not scheme.output_activations.dynamic
             is_attention = is_attention_module(module)
 
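The behavioral difference is in which modules get an input observer: before the revert, inputs whose dynamic setting was DynamicType.LOCAL were still observed; afterwards any truthy dynamic value skips observer initialization. A small sketch of the two predicates, using toy stand-ins rather than the real compressed_tensors types:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional, Union

    class DynamicType(Enum):  # toy stand-in for compressed_tensors' DynamicType
        LOCAL = "local"

    @dataclass
    class ToyActivationArgs:
        dynamic: Union[bool, DynamicType] = False

    @dataclass
    class ToyScheme:
        input_activations: Optional[ToyActivationArgs] = None

    scheme = ToyScheme(input_activations=ToyActivationArgs(dynamic=DynamicType.LOCAL))

    # Pre-revert predicate: locally-dynamic inputs still get an observer.
    pre_revert = scheme.input_activations and scheme.input_activations.dynamic in (
        False,
        DynamicType.LOCAL,
    )
    # Post-revert predicate: any truthy `dynamic` (including LOCAL) skips the observer.
    post_revert = scheme.input_activations and not scheme.input_activations.dynamic

    print(bool(pre_revert), bool(post_revert))  # True False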
src/llmcompressor/observers/base.py

Lines changed: 1 addition & 12 deletions
@@ -73,14 +73,11 @@ def post_calculate_qparams(self) -> None:
         Run any logic specific to its observers after running calculate_qparams
         """
 
-    # TODO: use a different name?
     def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
-        should_calculate_gparam: bool = False,
-        should_calculate_qparams: bool = True,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -104,14 +101,6 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                # Global params are for the entire tensor
-                if should_calculate_gparam:
-                    return self.calculate_qparams(
-                        observed,
-                        should_calculate_gparam=True,
-                        should_calculate_qparams=False,
-                    )
-
                 rows = observed.shape[0]
                 columns = observed.shape[1]
                 num_groups = int(ceil(columns / group_size))
@@ -148,7 +137,7 @@ def get_qparams(
                         observed[:, start:end],
                         0,
                         tensor_id=group_index,
-                        global_scale=global_scale,
+                        global_scale=global_scale
                     )
 
                     self._scale[:, group_index] = scale.squeeze(1)
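What remains after the revert is the plain group-wise path: split the columns into ceil(columns / group_size) groups and compute one scale per row and group. A rough standalone illustration of that loop, using a simple abs-max scale as a stand-in for calculate_qparams:

    import math
    import torch

    def toy_group_scales(observed: torch.Tensor, group_size: int) -> torch.Tensor:
        # One scale per (row, group), mirroring how get_qparams fills self._scale.
        rows, columns = observed.shape
        num_groups = int(math.ceil(columns / group_size))
        scales = torch.empty(rows, num_groups)
        for group_index in range(num_groups):
            start = group_index * group_size
            end = min(start + group_size, columns)
            # Simple symmetric int8-style scale for the group slice.
            scales[:, group_index] = observed[:, start:end].abs().amax(dim=1) / 127.0
        return scales

    print(toy_group_scales(torch.randn(2, 10), group_size=4).shape)  # torch.Size([2, 3])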

src/llmcompressor/observers/helpers.py

Lines changed: 1 addition & 30 deletions
@@ -1,14 +1,8 @@
 from collections import Counter
-from typing import Optional
 
 import torch
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    FloatArgs,
-)
 
-__all__ = ["get_observer_token_count", "calculate_gparam"]
+__all__ = ["get_observer_token_count"]
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
@@ -26,26 +20,3 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
             module._num_observed_tokens
         )
     return token_counts
-
-
-def calculate_gparam(
-    updated_min_val: torch.Tensor,
-    updated_max_val: torch.Tensor,
-    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
-    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
-    dtype: Optional[torch.dtype] = torch.float32,
-):
-    """
-    Generate a global scale for an entire tensor (input_tensor).
-    Goal of the scale is to ensure that the quantization (local) scale
-    falls into the appropriate dtype range.
-
-    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
-    attempts to use the entire FP8 dtype range while mapping a per-group max
-    to the FP4 max.
-    """
-    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
-    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
-    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-    global_scale = scale_data.max * quant_data.max / max_val_pos
-    return global_scale.to(dtype)
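For reference, the arithmetic in the removed calculate_gparam is small: take the largest absolute observed value and pick a scale so that per-group (FP4) maxima map onto the FP8 scale range. A self-contained sketch with the FP8-E4M3 and FP4-E2M1 maxima written as literals (448.0 and 6.0) rather than pulled from compressed_tensors:

    import torch

    def toy_calculate_gparam(
        min_val: torch.Tensor,
        max_val: torch.Tensor,
        scale_max: float = 448.0,  # FP8-E4M3 maximum (what FP8_E4M3_DATA.max provides)
        quant_max: float = 6.0,    # FP4-E2M1 maximum (what FP4_E2M1_DATA.max provides)
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        min_vals = torch.min(min_val, torch.zeros_like(min_val))
        max_vals = torch.max(max_val, torch.zeros_like(max_val))
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        # Chosen so the local (per-group) scales land inside the FP8 dtype range.
        return (scale_max * quant_max / max_val_pos).to(dtype)

    print(toy_calculate_gparam(torch.tensor(-3.0), torch.tensor(5.0)))  # 448 * 6 / 5 = 537.6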

src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 15 deletions
@@ -6,7 +6,6 @@
 from compressed_tensors.utils import deprecated
 
 from llmcompressor.observers.base import Observer
-from llmcompressor.observers.helpers import calculate_gparam
 
 __all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"]
 
@@ -36,8 +35,6 @@ def calculate_qparams(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
-        should_calculate_gparam: bool = False,
-        should_calculate_qparams: bool = True,
     ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
@@ -86,24 +83,13 @@ def calculate_qparams(
         self.min_val[tensor_id] = updated_min_val
         self.max_val[tensor_id] = updated_max_val
 
-        if should_calculate_gparam:
-            global_scale = calculate_gparam(
-                updated_min_val=updated_max_val, updated_max_val=updated_max_val
-            )
-            if not should_calculate_qparams:
-                return global_scale
-
-        scale, zero_point = calculate_qparams(
+        return calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
             quantization_args=self.quantization_args,
             global_scale=global_scale,
         )
 
-        if should_calculate_gparam:
-            return scale, zero_point, global_scale
-        return scale, zero_point
-
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,
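Independent of the revert, the observer's core update is a moving average of the running min and max, smoothed by an averaging constant. A toy version of that update, not the actual MovingAverageMinMaxObserver code:

    import torch

    def toy_update(running_min, running_max, observed: torch.Tensor, averaging_constant=0.01):
        # Smooth the running min/max toward the current batch's min/max.
        batch_min, batch_max = observed.amin(), observed.amax()
        if running_min is None:
            return batch_min, batch_max
        new_min = running_min + averaging_constant * (batch_min - running_min)
        new_max = running_max + averaging_constant * (batch_max - running_max)
        return new_min, new_max

    running_min, running_max = None, None
    for _ in range(3):
        running_min, running_max = toy_update(running_min, running_max, torch.randn(8, 8))
    print(running_min, running_max)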

src/llmcompressor/observers/mse.py

Lines changed: 1 addition & 15 deletions
@@ -6,7 +6,6 @@
 from torch import FloatTensor, IntTensor, Tensor
 
 from llmcompressor.observers.base import Observer
-from llmcompressor.observers.helpers import calculate_gparam
 
 __all__ = ["MovingAverageMSEObserver"]
 
@@ -116,8 +115,6 @@ def calculate_qparams(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
-        should_calculate_gparam: bool = False,
-        should_calculate_qparams: bool = True,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the mse-clipped min and max values of the observed tensor using
@@ -152,24 +149,13 @@ def calculate_qparams(
         self.min_val[tensor_id] = updated_min_val
         self.max_val[tensor_id] = updated_max_val
 
-        if should_calculate_gparam:
-            global_scale = calculate_gparam(
-                updated_min_val=updated_max_val, updated_max_val=updated_max_val
-            )
-            if not should_calculate_qparams:
-                return global_scale
-
-        scale, zero_point = calculate_qparams(
+        return calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
             quantization_args=self.quantization_args,
             global_scale=global_scale,
         )
 
-        if should_calculate_gparam:
-            return scale, zero_point, global_scale
-        return scale, zero_point
-
     def get_qparams_along_dim(
         self,
         observed,
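The "mse-clipped min and max" in the docstring refers to searching over progressively tighter clipping ranges and keeping the one with the lowest quantization error. A rough, hypothetical illustration of that idea, not the observer's actual grid-search parameters:

    import torch

    def toy_mse_clip(x: torch.Tensor, steps: int = 20, bits: int = 8):
        # Try progressively shrunken clipping ranges; keep the one with lowest MSE.
        best_err, best_min, best_max = float("inf"), x.amin(), x.amax()
        levels = 2 ** bits - 1
        for i in range(1, steps + 1):
            shrink = i / steps
            lo, hi = x.amin() * shrink, x.amax() * shrink
            scale = (hi - lo) / levels
            quantized = ((x.clamp(lo, hi) - lo) / scale).round() * scale + lo
            err = torch.mean((quantized - x) ** 2).item()
            if err < best_err:
                best_err, best_min, best_max = err, lo, hi
        return best_min, best_max

    print(toy_mse_clip(torch.randn(1024)))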

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 6 additions & 6 deletions
@@ -61,13 +61,13 @@ def infer_quantization_format(
         )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
-
         if is_weight_only:  # w4a16 and w8a16
+            if (
+                weight_args[0].num_bits == 4
+                and weight_args[0].type == QuantizationType.FLOAT.value
+            ):
+                return CompressionFormat.nvfp4_pack_quantized
+
             is_valid_pack = all(
                 weight_arg.num_bits in [4, 8]
                 and weight_arg.type == QuantizationType.INT.value
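The effect of this change is ordering: the NVFP4 packed format is only considered once the scheme has been established as weight-only, rather than before that check. A toy walk-through of the post-revert decision (format names and branches approximated as strings, not the full infer_quantization_format logic):

    def toy_infer_format(num_bits: int, qtype: str, weight_only: bool) -> str:
        if weight_only:  # w4a16 and w8a16
            if num_bits == 4 and qtype == "float":
                return "nvfp4-pack-quantized"
            if num_bits in (4, 8) and qtype == "int":
                return "pack-quantized"
            return "naive-quantized"
        return "int-quantized"

    print(toy_infer_format(4, "float", weight_only=True))   # nvfp4-pack-quantized
    print(toy_infer_format(4, "float", weight_only=False))  # int-quantized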
