-
Notifications
You must be signed in to change notification settings - Fork 249
[Observers] Small observers cleanup, add e2e quantization tests #1830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
7a3d43a
25e54a5
06ece86
178d0ae
7d81a79
05d780a
dd91329
a6b3842
c098447
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
from typing import Any, Optional, Tuple | ||
from typing import Any, Iterable, Optional, Tuple, Union | ||
|
||
import torch | ||
from compressed_tensors.quantization.quant_args import QuantizationArgs | ||
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam | ||
from compressed_tensors.utils import deprecated | ||
from compressed_tensors.utils import deprecated, patch_attr | ||
|
||
from llmcompressor.observers.base import Observer | ||
|
||
|
@@ -58,6 +58,8 @@ def calculate_updated_min_max( | |
|
||
# early stopping, save some computation and memory | ||
if self.averaging_constant == 1.0: | ||
self.min_val[tensor_id] = min_val | ||
self.max_val[tensor_id] = max_val | ||
return min_val, max_val | ||
|
||
running_min_val = self.min_val.get(tensor_id, None) | ||
|
@@ -86,9 +88,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor: | |
:return: updated global scale derived from the observed tensor | ||
""" | ||
|
||
updated_min_val, updated_max_val = self.calculate_updated_min_max( | ||
observed=observed | ||
) | ||
# patch to avoid affecting running means | ||
|
||
with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}): | ||
updated_min_val, updated_max_val = self.calculate_updated_min_max( | ||
observed=observed | ||
) | ||
return generate_gparam( | ||
updated_min_val=updated_min_val, updated_max_val=updated_max_val | ||
) | ||
|
@@ -126,14 +130,23 @@ def calculate_qparams( | |
def get_qparams_along_dim( | ||
self, | ||
observed: torch.Tensor, | ||
dim: int, | ||
dim: Union[int, Iterable[int]], | ||
tensor_id: Optional[Any] = None, | ||
global_scale: Optional[torch.Tensor] = None, | ||
): | ||
""" | ||
Calculate quantization parameters along the specified dimension | ||
""" | ||
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) | ||
# cast to set | ||
if isinstance(dim, int): | ||
dim = [dim] | ||
dim = set(dim) | ||
|
||
# convert negative dims | ||
dim = [d if d >= 0 else observed.ndim + d for d in dim] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't the cast to set happen after this line? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically either is fine, since the argument type just needs to be an iterable. I'm purely matching the implementation on the base model for now
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean more that you might end up with duplicates in e.g. if there are 3 dims and dim= |
||
|
||
# reduce all dimensions except the one passed as argument to this function | ||
kylesayrs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) | ||
return self.calculate_qparams( | ||
observed, | ||
reduce_dims=reduce_dims, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we only attach one observer, I’m fairly sure we’re resetting to prevent global scale metrics from impacting quant scale metrics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reset is replaced by patching and restoring the metrics