|
19 | 19 | from llmcompressor.observers import Observer
|
20 | 20 | from llmcompressor.utils.helpers import getattr_chain
|
21 | 21 |
|
22 |
| -DEFAULT_MAXSHRINK = 0.20 |
23 |
| -DEFAULT_PATIENCE = 5 |
24 |
| -DEFAULT_AVERAGING_CONSTANT = 0.01 |
25 |
| -DEFAULT_GRID = 100.0 |
26 |
| -DEFAULT_NORM = 2.4 |
27 |
| - |
28 | 22 | __all__ = [
|
29 | 23 | "initialize_observer",
|
30 | 24 | "update_weight_zp_scale",
|
@@ -61,24 +55,11 @@ def initialize_observer(
|
61 | 55 | # no quantization scheme nothing to do
|
62 | 56 | return
|
63 | 57 |
|
64 |
| - quantization_args = getattr(quantization_scheme, arg_name, None) |
| 58 | + args = getattr(quantization_scheme, arg_name, None) |
65 | 59 | # dont need observers for dynamic
|
66 |
| - if quantization_args is not None and quantization_args.dynamic in ( |
67 |
| - False, |
68 |
| - DynamicType.LOCAL, |
69 |
| - ): |
70 |
| - observer_kwargs = quantization_args.observer_kwargs or {} |
| 60 | + if args is not None and args.dynamic in (False, DynamicType.LOCAL): |
71 | 61 | observer = Observer.load_from_registry(
|
72 |
| - quantization_args.observer, |
73 |
| - quantization_args=quantization_args, |
74 |
| - averaging_constant=observer_kwargs.get( |
75 |
| - "averaging_constant", DEFAULT_AVERAGING_CONSTANT |
76 |
| - ), |
77 |
| - # used by mse observer only, will be ignored by minmax observer |
78 |
| - maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), |
79 |
| - patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), |
80 |
| - grid=observer_kwargs.get("grid", DEFAULT_GRID), |
81 |
| - norm=observer_kwargs.get("norm", DEFAULT_NORM), |
| 62 | + args.observer, base_name=base_name, args=args, module=module |
82 | 63 | )
|
83 | 64 | module.register_module(f"{base_name}_observer", observer)
|
84 | 65 |
|
@@ -110,13 +91,10 @@ def call_observer(
|
110 | 91 | "Must provide a value to observe if not using weight observer"
|
111 | 92 | )
|
112 | 93 |
|
113 |
| - observer = getattr(module, f"{base_name}_observer") |
| 94 | + observer: Observer = getattr(module, f"{base_name}_observer") |
114 | 95 |
|
115 | 96 | if should_calculate_gparam:
|
116 |
| - global_scale = observer( |
117 |
| - value, |
118 |
| - should_calculate_gparam=True, |
119 |
| - ) |
| 97 | + global_scale = observer.get_global_scale(value) |
120 | 98 | update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
|
121 | 99 | else:
|
122 | 100 | global_scale = getattr(module, f"{base_name}_global_scale", None)
|
@@ -148,7 +126,6 @@ def update_weight_global_scale(module: Module):
|
148 | 126 | should_calculate_gparam=True,
|
149 | 127 | should_calculate_qparams=False,
|
150 | 128 | )
|
151 |
| - module.weight_observer.reset() |
152 | 129 |
|
153 | 130 |
|
154 | 131 | def update_weight_zp_scale(module: Module):
|
|
0 commit comments