5
5
from compressed_tensors .quantization import (
6
6
DynamicType ,
7
7
KVCacheScaleType ,
8
+ QuantizationArgs ,
8
9
QuantizationScheme ,
9
10
QuantizationStatus ,
10
11
QuantizationStrategy ,
19
20
from llmcompressor .observers import Observer
20
21
from llmcompressor .utils .helpers import getattr_chain
21
22
22
- DEFAULT_MAXSHRINK = 0.20
23
- DEFAULT_PATIENCE = 5
24
- DEFAULT_AVERAGING_CONSTANT = 0.01
25
- DEFAULT_GRID = 100.0
26
- DEFAULT_NORM = 2.4
27
-
28
23
__all__ = [
29
24
"initialize_observer" ,
30
25
"update_weight_zp_scale" ,
@@ -54,31 +49,19 @@ def initialize_observer(
54
49
:param base_name: str used to name the observer attribute
55
50
56
51
"""
57
-
58
- arg_name = "weights" if base_name == "weight" else f"{ base_name } _activations"
59
- quantization_scheme = getattr (module , "quantization_scheme" , None )
60
- if not quantization_scheme :
61
- # no quantization scheme nothing to do
62
- return
63
-
64
- quantization_args = getattr (quantization_scheme , arg_name , None )
65
- # dont need observers for dynamic
66
- if quantization_args is not None and quantization_args .dynamic in (
67
- False ,
68
- DynamicType .LOCAL ,
69
- ):
70
- observer_kwargs = quantization_args .observer_kwargs or {}
52
+ if base_name == "weight" :
53
+ arg_name = "weights"
54
+ elif base_name == "output" :
55
+ arg_name = "output_activations"
56
+ else : # input, q, k, v
57
+ arg_name = "input_activations"
58
+
59
+ args : QuantizationArgs = getattr_chain (
60
+ module , f"quantization_scheme.{ arg_name } " , None
61
+ )
62
+ if args is not None and args .dynamic is not True :
71
63
observer = Observer .load_from_registry (
72
- quantization_args .observer ,
73
- quantization_args = quantization_args ,
74
- averaging_constant = observer_kwargs .get (
75
- "averaging_constant" , DEFAULT_AVERAGING_CONSTANT
76
- ),
77
- # used by mse observer only, will be ignored by minmax observer
78
- maxshrink = observer_kwargs .get ("maxshrink" , DEFAULT_MAXSHRINK ),
79
- patience = observer_kwargs .get ("patience" , DEFAULT_PATIENCE ),
80
- grid = observer_kwargs .get ("grid" , DEFAULT_GRID ),
81
- norm = observer_kwargs .get ("norm" , DEFAULT_NORM ),
64
+ args .observer , base_name = base_name , args = args , module = module
82
65
)
83
66
module .register_module (f"{ base_name } _observer" , observer )
84
67
@@ -100,36 +83,17 @@ def call_observer(
100
83
base_name is "weight", then the module's weight tensor will be used
101
84
"""
102
85
with align_module_device (module ):
103
- if base_name == "weight" :
104
- value = module .weight
105
- g_idx = getattr (module , "weight_g_idx" , None )
106
- elif value is not None :
107
- g_idx = None
108
- else :
109
- raise ValueError (
110
- "Must provide a value to observe if not using weight observer"
111
- )
112
-
113
- observer = getattr (module , f"{ base_name } _observer" )
86
+ value = module .weight if base_name == "weight" else value
87
+ observer : Observer = getattr (module , f"{ base_name } _observer" )
114
88
115
89
if should_calculate_gparam :
116
- global_scale = observer (
117
- value ,
118
- should_calculate_gparam = True ,
119
- )
90
+ global_scale = observer .get_global_scale (value )
120
91
update_offload_parameter (module , f"{ base_name } _global_scale" , global_scale )
121
- else :
122
- global_scale = getattr (module , f"{ base_name } _global_scale" , None )
123
92
124
93
if should_calculate_qparams :
125
- updated_scale , updated_zero_point = observer (
126
- value , g_idx = g_idx , global_scale = global_scale
127
- )
128
- # register or update scale & zero_point parameters (supports block shapes)
129
- scale_name = f"{ base_name } _scale"
130
- zp_name = f"{ base_name } _zero_point"
131
- update_offload_parameter (module , scale_name , updated_scale )
132
- update_offload_parameter (module , zp_name , updated_zero_point )
94
+ scale , zero_point = observer (value )
95
+ update_offload_parameter (module , f"{ base_name } _scale" , scale )
96
+ update_offload_parameter (module , f"{ base_name } _zero_point" , zero_point )
133
97
134
98
135
99
def update_weight_global_scale (module : Module ):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
148
112
should_calculate_gparam = True ,
149
113
should_calculate_qparams = False ,
150
114
)
151
- module .weight_observer .reset ()
152
115
153
116
154
117
def update_weight_zp_scale (module : Module ):
0 commit comments