Commit 79c7e86

refactor observers
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 33ef5f4 commit 79c7e86

14 files changed: +733 -749 lines


docs/observers.md

Lines changed: 5 additions & 1 deletion
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
 from compressed_tensors.quantization.quant_args import QuantizationArgs

 args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
-observer = Observer.load_from_registry("minmax", quantization_args=args)
+observer = Observer.load_from_registry(
+    "minmax",
+    base_name="weight",
+    quantization_args=args,
+)

 x = torch.randn(64, 512)
 scale, zero_point = observer(x)
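Put together, the updated documentation example reads roughly as below. This is a minimal sketch assembled from the hunk above; the torch import is assumed to appear earlier in the doc.

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer

# 4-bit, group-wise quantization with groups of 128 columns
args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)

# observers are now constructed with an explicit base_name
observer = Observer.load_from_registry(
    "minmax",
    base_name="weight",
    quantization_args=args,
)

# calling the observer returns the calibrated quantization parameters
x = torch.randn(64, 512)
scale, zero_point = observer(x)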

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 6 additions & 4 deletions
@@ -86,13 +86,15 @@ def update(
         """

         if len(self.k_observers) <= layer_idx:
-            k_observer_name = self.quantization_args.observer
             k_observer = Observer.load_from_registry(
-                k_observer_name, quantization_args=self.quantization_args
+                self.quantization_args.observer,
+                base_name="k",
+                args=self.quantization_args,
             )
-            v_observer_name = self.quantization_args.observer
             v_observer = Observer.load_from_registry(
-                v_observer_name, quantization_args=self.quantization_args
+                self.quantization_args.observer,
+                base_name="v",
+                args=self.quantization_args,
             )

             # NOTE: User may ignore some layers in configuration,
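The net effect of this hunk is that the KV cache now builds its per-layer key and value observers with distinct base_name tags instead of reusing a shared observer-name variable. A minimal sketch of that construction, with the surrounding cache class simplified to a plain object and the list appends assumed from context (the hypothetical ensure_layer_observers helper exists only for illustration):

from llmcompressor.observers import Observer

def ensure_layer_observers(cache, layer_idx: int) -> None:
    # one observer per layer for keys and one for values; base_name lets the
    # observer (and parameter names derived from it, e.g. "k_scale"/"v_scale")
    # identify which tensor it calibrates
    while len(cache.k_observers) <= layer_idx:
        k_observer = Observer.load_from_registry(
            cache.quantization_args.observer,
            base_name="k",
            args=cache.quantization_args,
        )
        v_observer = Observer.load_from_registry(
            cache.quantization_args.observer,
            base_name="v",
            args=cache.quantization_args,
        )
        cache.k_observers.append(k_observer)  # assumed: lists grow per layer
        cache.v_observers.append(v_observer)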

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 19 additions & 56 deletions
@@ -5,6 +5,7 @@
 from compressed_tensors.quantization import (
     DynamicType,
     KVCacheScaleType,
+    QuantizationArgs,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
@@ -19,12 +20,6 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain

-DEFAULT_MAXSHRINK = 0.20
-DEFAULT_PATIENCE = 5
-DEFAULT_AVERAGING_CONSTANT = 0.01
-DEFAULT_GRID = 100.0
-DEFAULT_NORM = 2.4
-
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
     :param base_name: str used to name the observer attribute

     """
-
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        # no quantization scheme nothing to do
-        return
-
-    quantization_args = getattr(quantization_scheme, arg_name, None)
-    # dont need observers for dynamic
-    if quantization_args is not None and quantization_args.dynamic in (
-        False,
-        DynamicType.LOCAL,
-    ):
-        observer_kwargs = quantization_args.observer_kwargs or {}
+    if base_name == "weight":
+        arg_name = "weights"
+    elif base_name == "output":
+        arg_name = "output_activations"
+    else:  # input, q, k, v
+        arg_name = "input_activations"
+
+    args: QuantizationArgs = getattr_chain(
+        module, f"quantization_scheme.{arg_name}", None
+    )
+    if args is not None and args.dynamic is not True:
         observer = Observer.load_from_registry(
-            quantization_args.observer,
-            quantization_args=quantization_args,
-            averaging_constant=observer_kwargs.get(
-                "averaging_constant", DEFAULT_AVERAGING_CONSTANT
-            ),
-            # used by mse observer only, will be ignored by minmax observer
-            maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
-            patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
-            grid=observer_kwargs.get("grid", DEFAULT_GRID),
-            norm=observer_kwargs.get("norm", DEFAULT_NORM),
+            args.observer, base_name=base_name, args=args, module=module
         )
         module.register_module(f"{base_name}_observer", observer)

@@ -100,36 +83,17 @@ def call_observer(
         base_name is "weight", then the module's weight tensor will be used
     """
     with align_module_device(module):
-        if base_name == "weight":
-            value = module.weight
-            g_idx = getattr(module, "weight_g_idx", None)
-        elif value is not None:
-            g_idx = None
-        else:
-            raise ValueError(
-                "Must provide a value to observe if not using weight observer"
-            )
-
-        observer = getattr(module, f"{base_name}_observer")
+        value = module.weight if base_name == "weight" else value
+        observer: Observer = getattr(module, f"{base_name}_observer")

         if should_calculate_gparam:
-            global_scale = observer(
-                value,
-                should_calculate_gparam=True,
-            )
+            global_scale = observer.get_global_scale(value)
             update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
-        else:
-            global_scale = getattr(module, f"{base_name}_global_scale", None)

         if should_calculate_qparams:
-            updated_scale, updated_zero_point = observer(
-                value, g_idx=g_idx, global_scale=global_scale
-            )
-            # register or update scale & zero_point parameters (supports block shapes)
-            scale_name = f"{base_name}_scale"
-            zp_name = f"{base_name}_zero_point"
-            update_offload_parameter(module, scale_name, updated_scale)
-            update_offload_parameter(module, zp_name, updated_zero_point)
+            scale, zero_point = observer(value)
+            update_offload_parameter(module, f"{base_name}_scale", scale)
+            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)


 def update_weight_global_scale(module: Module):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()


 def update_weight_zp_scale(module: Module):
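The largest behavioral change is in initialize_observer: base_name is now mapped explicitly onto the quantization-scheme attribute, the observer tuning defaults (maxshrink, patience, grid, norm, averaging_constant) are no longer forwarded from this call site, and args/module are passed straight to the observer. A standalone sketch of the new mapping, using a hypothetical resolve_arg_name helper purely for illustration:

def resolve_arg_name(base_name: str) -> str:
    # mirrors the branch added in initialize_observer above
    if base_name == "weight":
        return "weights"
    elif base_name == "output":
        return "output_activations"
    else:  # "input", "q", "k", "v" all read the input-activation args
        return "input_activations"

assert resolve_arg_name("weight") == "weights"
assert resolve_arg_name("q") == "input_activations"
assert resolve_arg_name("output") == "output_activations"

call_observer is simplified along the same lines: the global scale now comes from observer.get_global_scale(value), and the scale/zero-point update collapses to a single observer(value) call followed by update_offload_parameter.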

src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 4 additions & 2 deletions
@@ -95,8 +95,10 @@ def quantize_weight(

     # create observer for calculating quantization parameters
     observer = Observer.load_from_registry(
-        quant_args.observer,
-        quantization_args=quant_args,
+        "minmax",
+        base_name="weight",
+        args=quant_args,
+        module=module,
         averaging_constant=1.0,  # ignore moving average
     )
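Here the registry name is hard-coded to "minmax" with averaging_constant=1.0, so GPTQ derives its quantization parameters from a single observation of the weight rather than a moving average. A sketch of the resulting call, assuming module and quant_args are in scope as they are inside quantize_weight; the final observer call is an assumption based on the documented observer signature, not part of this diff:

from llmcompressor.observers import Observer

# inside quantize_weight(...), with `module` and `quant_args` in scope:
observer = Observer.load_from_registry(
    "minmax",                # hard-coded min-max statistics for GPTQ
    base_name="weight",
    args=quant_args,
    module=module,
    averaging_constant=1.0,  # one-shot observation, no moving average
)
# assumed usage, matching the documented observer call signature
scale, zero_point = observer(module.weight)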
