src/llmcompressor/modifiers/quantization/calibration.py (1 change: 0 additions & 1 deletion)

@@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()
Collaborator:
Because we only attach one observer, I’m fairly sure we’re resetting to prevent global scale metrics from impacting quant scale metrics

Collaborator Author:
Update calculate_gparam to restore original running values, rather than relying on resetting after calculation

Reset is replaced by patching and restoring the metrics
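
For context on the patch-and-restore approach: patch_attr is a context manager that temporarily overrides an attribute and restores the original value on exit. A minimal sketch of the idea, using only standard-library tools (the real helper lives in compressed_tensors.utils; this reimplementation is for illustration, not the library's actual code):

from contextlib import contextmanager

@contextmanager
def patch_attr_sketch(obj, attr, value):
    # remember the original value (or absence) of the attribute
    sentinel = object()
    original = getattr(obj, attr, sentinel)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        # restore or remove the attribute, even if the body raised
        if original is sentinel:
            delattr(obj, attr)
        else:
            setattr(obj, attr, original)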



 def update_weight_zp_scale(module: Module):
src/llmcompressor/observers/min_max.py (27 changes: 20 additions & 7 deletions)

@@ -1,9 +1,9 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
-from compressed_tensors.utils import deprecated
+from compressed_tensors.utils import deprecated, patch_attr

 from llmcompressor.observers.base import Observer

@@ -58,6 +58,8 @@ def calculate_updated_min_max(

         # early stopping, save some computation and memory
         if self.averaging_constant == 1.0:
+            self.min_val[tensor_id] = min_val
+            self.max_val[tensor_id] = max_val
             return min_val, max_val

         running_min_val = self.min_val.get(tensor_id, None)
@@ -86,9 +88,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
         :return: updated global scale derived from the observed tensor
         """

-        updated_min_val, updated_max_val = self.calculate_updated_min_max(
-            observed=observed
-        )
+        # patch to avoid affecting running means
Collaborator:
This is because we are calculating the global scale, right? We don't want the calculate_qparams result to change based on this calculation?

Collaborator Author:
Update calculate_gparam to restore original running values, rather than relying on resetting after calculation

Yes

Collaborator:
Why is this preferable? If anything, this now seems more confusing

Collaborator Author:
From a programming standpoint, this decouples calculate_gparam and Observer.reset (there's no way to footgun yourself by calling calculate_gparam and forgetting to call Observer.reset).

From a functionality standpoint, I think this fixes a bug where metrics would be updated twice (which has implications for running values), specifically when called from calibrate_activations. In the case of activations, we don't want to reset after each gparam calculation, since we still need those metrics to compute running values.

Collaborator:
I think you're right about the 2nd point.

I don't know if I agree with the first point. This feels like a hack.

Collaborator Author:
Case 1

Consider the case of strategy="tensor_group", dynamic=False and averaging_constant != 1.

On activation hook, calibrate_activations calls call_observer with should_calculate_gparam=True*. This causes calculate_updated_min_max to be called twice, which causes the running min/max to move faster than if no global param was calculated.

Case 2

Consider the case of strategy="tensor_group", dynamic="local" and averaging_constant != 1.

Originally, calculate_gparam would call calculate_updated_min_max and the running values would update (twice*). Now, the running values will not update.

* Note that running values are updated, even if should_calculate_qparams=False

TLDR

So it seems that this change fixes a bug where running values are updated twice, but changes the behavior of dynamic="local" to calculate global parameters based on true values, not running values. I assumed that global parameters should be the true min/max of all values, not running values, but maybe @dsikka you think this shouldn't be the case?

I've reverted the change since it's not necessary for group quant, but we should definitely look into exactly the behavior we want for global scales (and scales in general; running means are slightly strange anyway and seem to be a vestige of QAT).
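
To make Case 1 concrete, here is a small arithmetic sketch of a moving-average min/max update being applied twice per batch. The update rule is paraphrased from the observer's running-value logic, and the numbers are hypothetical:

averaging_constant = 0.01
running_max = 1.0
observed_max = 2.0

# one update per calibration step (intended behavior)
once = running_max * (1 - averaging_constant) + observed_max * averaging_constant
print(once)  # 1.01

# two updates per step (gparam path and qparam path both observing)
twice = once * (1 - averaging_constant) + observed_max * averaging_constant
print(twice)  # ~1.0199 -- the running value moves toward the observation at roughly twice the intended rate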

+        with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}):
+            updated_min_val, updated_max_val = self.calculate_updated_min_max(
+                observed=observed
+            )
         return generate_gparam(
             updated_min_val=updated_min_val, updated_max_val=updated_max_val
         )
@@ -126,14 +130,23 @@ def calculate_qparams(
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,
-        dim: int,
+        dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
         Calculate quantization parameters along the specified dimension
         """
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        # cast to set
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
Collaborator:
Shouldn't the cast to set happen after this line?

Collaborator Author:
Technically either is fine, since the argument type just needs to be an iterable. I'm purely matching the implementation on the base class for now.

Update get_qparams_along_dim to support multiple dims and negative dims
This actually results in a silent typing bug with token quantization, and is fixed on the base class implementation
This change essentially duplicates the base class implementation. Future work could involve cleaning up the inheritance structure here
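
To illustrate the silent typing bug: a sketch assuming the token strategy passes a collection of dims such as dim={0, 1}, as the base class implementation does:

observed_ndim = 3
dim = {0, 1}  # assumed: what token quantization passes to get_qparams_along_dim

# old logic: an int index is never equal to a set, so the condition is
# always True and no dimension is excluded from the reduction
old_reduce_dims = tuple(idx for idx in range(observed_ndim) if idx != dim)
print(old_reduce_dims)  # (0, 1, 2) -- silently reduces over every dim

# fixed logic from this diff: membership test excludes the requested dims
new_reduce_dims = tuple(idx for idx in range(observed_ndim) if idx not in dim)
print(new_reduce_dims)  # (2,)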

Collaborator:
I mean more that you might end up with duplicates in dim if you create this list and don't cast back to a set.

e.g. if there are 3 dims and dim={1,2,-1}, then dim=[1,2,2] after this line.
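
A quick runnable sketch of that point, with hypothetical values:

ndim = 3
dim = {1, 2, -1}

# converting negative dims after the set cast can reintroduce duplicates
converted = [d if d >= 0 else ndim + d for d in dim]
print(sorted(converted))  # [1, 2, 2] -- duplicate index
print(set(converted))     # {1, 2} -- casting back to a set would dedupe

In this particular function the duplicate is harmless, since the list is only used for membership tests, but deduplicating after the conversion is the safer ordering.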


+        # reduce all dimensions except the one passed as argument to this function
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
             reduce_dims=reduce_dims,