diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index bf36824826..7a22d3a169 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -172,10 +172,6 @@ class AWQModifier(Modifier, QuantizationMixin): ) # List to store error metrics for each layer _error_metrics: list[dict] = PrivateAttr(default_factory=list) - # Cache FP16 baseline outputs for each parent module, one list of tensors per batch - _fp16_baseline_cache: dict[Module, IntermediatesCache] = PrivateAttr( - default_factory=dict - ) def on_initialize(self, state: State, **kwargs) -> bool: """ diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index f2ade1e88a..62998be410 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -12,37 +12,6 @@ from tqdm import tqdm -class OverrideEqMode(TorchDispatchMode): - """ - When using a torch.Tensor as a key in a dictionary, the equality - check must return a single value instead of a torch.Tensor - of bool values. - Use this override context for such cases, to swap out the torch.eq - equality check for a check on id - >>> a = torch.tensor([1,2,3]) - >>> b = torch.tensor([1,2,3]) - >>> a == b - tensor([True, True, True]) - >>> with OverrideEqMode(): - ... a == b - tensor(True) - """ - - def __torch_dispatch__(self, func, _types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Check if the operation is equality - if func is torch.ops.aten.eq.Tensor: - # Override to use torch.equal - assert len(args) == 2, "Exactly 2 args must be provided" - - # NOTE: Errors out without cast to torch.tensor - return torch.tensor(id(args[0]) == id(args[1])) - - # For all other operations, just run them normally - return func(*args, **kwargs) - - @dataclass class IntermediateValue: """ @@ -289,7 +258,8 @@ def _offload_value( else: # move to offload if no hit offloaded = value.to(device=offload_device) - cls.offload_values[value] = offloaded + if offloaded is not value: # avoid circular ref + cls.offload_values[value] = offloaded return IntermediateValue( value=offloaded, @@ -326,3 +296,34 @@ def _offload_value( ): warnings.warn(f"Offloading not implemented for type {type(value)}.") return IntermediateValue(value=value, device=None) + + +class OverrideEqMode(TorchDispatchMode): + """ + When using a torch.Tensor as a key in a dictionary, the equality + check must return a single value instead of a torch.Tensor + of bool values. + Use this override context for such cases, to swap out the torch.eq + equality check for a check on id + >>> a = torch.tensor([1,2,3]) + >>> b = torch.tensor([1,2,3]) + >>> a == b + tensor([True, True, True]) + >>> with OverrideEqMode(): + ... a == b + tensor(True) + """ + + def __torch_dispatch__(self, func, _types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Check if the operation is equality + if func is torch.ops.aten.eq.Tensor: + # Override to use torch.equal + assert len(args) == 2, "Exactly 2 args must be provided" + + # NOTE: Errors out without cast to torch.tensor + return torch.tensor(id(args[0]) == id(args[1])) + + # For all other operations, just run them normally + return func(*args, **kwargs)