Skip to content
51 changes: 49 additions & 2 deletions src/llmcompressor/pipelines/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,42 @@
from collections import defaultdict
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Generator
from weakref import WeakKeyDictionary

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from tqdm import tqdm


class OverrideEqMode(TorchDispatchMode):
    """
    Dispatch mode that collapses element-wise tensor equality into a single
    scalar result.

    Using a ``torch.Tensor`` as a dictionary key triggers ``==`` checks that
    must yield one truth value, not a tensor of per-element booleans. Inside
    this context, ``aten.eq.Tensor`` is rerouted through ``torch.equal`` so
    the comparison produces a 0-dim bool tensor instead.

    >>> a = torch.tensor([1,2,3])
    >>> b = torch.tensor([1,2,3])
    >>> a == b
    tensor([True, True, True])
    >>> with OverrideEqMode():
    ...     a == b
    tensor(True)
    """

    def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Anything other than tensor equality dispatches untouched
        if func is not torch.ops.aten.eq.Tensor:
            return func(*args, **kwargs)

        # Collapse element-wise eq into a single truth value.
        # NOTE: must be wrapped back into a tensor — returning the raw
        # Python bool from torch.equal errors out of the dispatch.
        return torch.tensor(torch.equal(*args, **kwargs))


@dataclass
class IntermediateValue:
"""
Expand Down Expand Up @@ -42,6 +73,10 @@ class IntermediatesCache:
batch_intermediates: list[IntermediateValues]
offload_device: torch.device | None

# map of onload value -> offload value
# used to avoid excess memory usage when shared tensors are offloaded
offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()

def __init__(
self,
batch_intermediates: list[IntermediateValues] | None = None,
Expand Down Expand Up @@ -154,13 +189,16 @@ def size(self) -> dict[torch.device, int]:
:return: dictionary mapping torch device to number of bytes in cache
"""
sizes = defaultdict(lambda: 0)
memo = set()

def _size_helper(intermediate: IntermediateValue) -> int:
value = intermediate.value

match value:
case torch.Tensor():
sizes[value.device] += value.nbytes
if value not in memo:
sizes[value.device] += value.nbytes
memo.add(value)
case list() | tuple():
for v in value:
_size_helper(v)
Expand Down Expand Up @@ -239,8 +277,17 @@ def _offload_value(
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
match value:
case torch.Tensor():
with OverrideEqMode():
# check for cache hit between shared tensors
if value in cls.offload_values:
offloaded = cls.offload_values[value]
else:
# move to offload if no hit
offloaded = value.to(device=offload_device)
cls.offload_values[value] = offloaded

return IntermediateValue(
value=value.to(device=offload_device),
value=offloaded,
device=(onload_device if onload_device else value.device),
)
case list():
Expand Down
17 changes: 16 additions & 1 deletion tests/llmcompressor/pipelines/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch.utils.data import DataLoader, StackDataset

from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.cache import IntermediatesCache, OverrideEqMode


@dataclass
Expand Down Expand Up @@ -162,3 +162,18 @@ def deep_equal(a, b) -> bool:
return deep_equal(a_dict, b_dict)
case _:
return a == b


def test_override_eq_mode():
    """OverrideEqMode collapses tensor == into a single truth value."""
    base = torch.tensor([1, 2, 3])
    same = torch.tensor([1, 2, 3])
    other = torch.tensor([2, 2, 2])

    # Outside the mode, truth-testing a multi-element comparison is ambiguous
    with pytest.raises(RuntimeError):
        assert base == same
    with pytest.raises(RuntimeError):
        assert not (base == other)

    # Inside the mode, == yields a single boolean tensor
    with OverrideEqMode():
        assert base == same
        assert not (base == other)
Loading