Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ class AWQModifier(Modifier, QuantizationMixin):
)
# List to store error metrics for each layer
_error_metrics: list[dict] = PrivateAttr(default_factory=list)
# Cache FP16 baseline outputs for each parent module, one list of tensors per batch
_fp16_baseline_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down
65 changes: 33 additions & 32 deletions src/llmcompressor/pipelines/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,6 @@
from tqdm import tqdm


class OverrideEqMode(TorchDispatchMode):
    """
    When using a torch.Tensor as a key in a dictionary, the equality
    check must return a single value instead of a torch.Tensor
    of bool values.
    Use this override context for such cases, to swap out the torch.eq
    equality check for a check on id (object identity). Note that two
    distinct tensors compare unequal under this mode even when their
    values are element-wise equal:
    >>> a = torch.tensor([1,2,3])
    >>> b = torch.tensor([1,2,3])
    >>> a == b
    tensor([True, True, True])
    >>> with OverrideEqMode():
    ...     a == b
    tensor(False)
    >>> with OverrideEqMode():
    ...     a == a
    tensor(True)
    """

    def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Check if the operation is equality
        if func is torch.ops.aten.eq.Tensor:
            # Override element-wise torch.eq with an identity comparison,
            # so tensor dict-key lookups (which hash by id) resolve to a
            # single boolean rather than an element-wise bool tensor.
            assert len(args) == 2, "Exactly 2 args must be provided"

            # NOTE: Errors out without cast to torch.tensor
            return torch.tensor(id(args[0]) == id(args[1]))

        # For all other operations, just run them normally
        return func(*args, **kwargs)


@dataclass
class IntermediateValue:
"""
Expand Down Expand Up @@ -286,7 +255,8 @@ def _offload_value(
else:
# move to offload if no hit
offloaded = value.to(device=offload_device)
cls.offload_values[value] = offloaded
if offloaded is not value: # avoid circular ref
cls.offload_values[value] = offloaded

return IntermediateValue(
value=offloaded,
Expand Down Expand Up @@ -323,3 +293,34 @@ def _offload_value(
):
warnings.warn(f"Offloading not implemented for type {type(value)}.")
return IntermediateValue(value=value, device=None)


class OverrideEqMode(TorchDispatchMode):
    """
    Dispatch mode that makes tensor equality return a single scalar.

    Dictionary lookups keyed on torch.Tensor objects require ``==`` to
    produce one boolean, not an element-wise tensor of booleans. Inside
    this context, ``aten.eq.Tensor`` is replaced with an object-identity
    comparison, so a tensor is only "equal" to itself:
    >>> a = torch.tensor([1,2,3])
    >>> with OverrideEqMode():
    ...     a == a
    tensor(True)

    Every other torch operation is dispatched unchanged.
    """

    def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Anything other than tensor equality passes straight through.
        if func is not torch.ops.aten.eq.Tensor:
            return func(*args, **kwargs)

        assert len(args) == 2, "Exactly 2 args must be provided"

        # Identity check; wrapped in a tensor because returning a bare
        # bool from __torch_dispatch__ raises an error.
        return torch.tensor(args[0] is args[1])
Loading