Skip to content

Commit 5c8206f

Browse files
committed
add back get_global_scale
Signed-off-by: Kyle Sayers <[email protected]>
1 parent be2d34b commit 5c8206f

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/llmcompressor/observers/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from compressed_tensors.quantization.quant_args import (
88
QuantizationArgs,
99
)
10-
from compressed_tensors.quantization.utils import calculate_qparams
10+
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
1111
from compressed_tensors.registry.registry import RegistryMixin
1212

1313
from llmcompressor.observers.helpers import flatten_for_calibration
@@ -75,6 +75,19 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
7575
global_scale=global_scale,
7676
)
7777

78+
def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter:
    """
    Compute an updated global scale parameter from an observed tensor.

    :param observed: value being observed
    :return: calibrated global parameter
    """
    # Collapse to a (1, 1, N) view so min/max are computed per-tensor.
    flattened = observed.reshape((1, 1, -1))
    lows, highs = self.get_min_max(flattened)
    # Derive the global scale parameter from the observed value range.
    return generate_gparam(lows, highs)
90+
7891
def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
7992
if self.module is None:
8093
return None

0 commit comments

Comments
 (0)