
Commit c63986a

fix typos

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 15a551e

File tree: 3 files changed (+20, -4)


src/llmcompressor/observers/helpers.py

Lines changed: 17 additions & 0 deletions

@@ -22,6 +22,23 @@ def flatten_for_calibration(
     args: QuantizationArgs,
     g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    """
+    Reshapes the value according to the quantization strategy for the purposes of
+    scale/zp calibration. The value after flattening has the following shape:
+
+    `(num_observations, *qparam_shape, group_size)`
+
+    The first dim is the number of observations (usually the batch size times number of
+    tokens), the middle dims are the dimension of the scales, and the last dim is the
+    number of elements being quantized per group.
+
+    :param value: value being flattened
+    :param base_name: weight, input, output, q/k/v. Used to characterize the value as
+        being a weight, activation, or attention state
+    :param args: quantization args for determining how the value is flattened
+    :param g_idx: optional gidx for weight activation ordering
+    :return: value which has been reshaped for calibration
+    """
     if base_name == "weight":
         return _flatten_weight(value, args, g_idx)
     elif base_name in ("input", "output"):
src/llmcompressor/observers/min_max.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ def __init__(
         module: Optional[torch.nn.Module] = None,
         **observer_kwargs,
     ):
-        super().__init__(module, base_name, args)
+        super().__init__(base_name, args, module, **observer_kwargs)

         observer_kwargs = self.args.observer_kwargs
         self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01)
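The one-line fix matters because the old call passed the parent's positional arguments in the wrong order and never forwarded `observer_kwargs`. A minimal sketch of the failure mode, with the `Observer` base signature inferred from the corrected call rather than copied from the repo:

from typing import Optional

import torch

class Observer:
    # Signature inferred from the fixed super().__init__ call; an assumption.
    def __init__(self, base_name: str, args, module: Optional[torch.nn.Module] = None,
                 **observer_kwargs):
        self.base_name = base_name
        self.args = args
        self.module = module

# Old call: super().__init__(module, base_name, args)
# binds module -> base_name, base_name -> args, args -> module,
# silently misassigning all three, and drops **observer_kwargs entirely.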

src/llmcompressor/observers/mse.py

Lines changed: 2 additions & 3 deletions

@@ -23,7 +23,7 @@ def __init__(
         module: Optional[torch.nn.Module] = None,
         **observer_kwargs,
     ):
-        super().__init__(module, base_name, args)
+        super().__init__(base_name, args, module, **observer_kwargs)

         observer_kwargs = self.args.observer_kwargs
         self.maxshrink = observer_kwargs.get("maxshrink", 0.20)

@@ -51,8 +51,7 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         )
         min_val = torch.ones_like(absolute_min_val)
         max_val = torch.zeros_like(absolute_max_val)
-
-        global_scale = getattr(self.parent(), f"{self.base_name}_global_scale", None)
+        global_scale = self._get_module_param("global_scale")

         # Early stopping params
         no_improve_count = 0
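The second hunk swaps a raw `getattr` lookup for a `_get_module_param` helper. A plausible sketch of that helper, assuming it wraps the same prefixed-attribute lookup on the parent module that the removed line performed; this body is an assumption, not code from the repo:

import weakref

import torch

class Observer:
    # Minimal scaffolding; only _get_module_param is the point here.
    def __init__(self, base_name: str, module: torch.nn.Module):
        self.base_name = base_name
        self.parent = weakref.ref(module)  # old code called self.parent()

    def _get_module_param(self, name: str):
        # Assumed behavior: resolve "{base_name}_{name}" on the observed
        # module, returning None if absent, mirroring the getattr call
        # that this commit removes.
        return getattr(self.parent(), f"{self.base_name}_{name}", None)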
