diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 50153a085a..f8f6d6ab08 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -5,11 +5,44 @@ from collections import defaultdict from dataclasses import dataclass, fields, is_dataclass from typing import Any, Generator +from weakref import WeakKeyDictionary import torch +from torch.utils._python_dispatch import TorchDispatchMode from tqdm import tqdm +class OverrideEqMode(TorchDispatchMode): + """ + When using a torch.Tensor as a key in a dictionary, the equality + check must return a single value instead of a torch.Tensor + of bool values. + Use this override context for such cases, to swap out the torch.eq + equality check for a check on id + >>> a = torch.tensor([1,2,3]) + >>> b = torch.tensor([1,2,3]) + >>> a == b + tensor([True, True, True]) + >>> with OverrideEqMode(): + ... a == b + tensor(False) + """ + + def __torch_dispatch__(self, func, _types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Check if the operation is equality + if func is torch.ops.aten.eq.Tensor: + # Override to compare by object identity (id), not element-wise values + assert len(args) == 2, "Exactly 2 args must be provided" + + # NOTE: Errors out without cast to torch.tensor + return torch.tensor(id(args[0]) == id(args[1])) + + # For all other operations, just run them normally + return func(*args, **kwargs) + + @dataclass class IntermediateValue: """ @@ -42,6 +75,10 @@ class IntermediatesCache: batch_intermediates: list[IntermediateValues] offload_device: torch.device | None + # map of onload value -> offload value + # used to avoid excess memory usage when shared tensors are offloaded + offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary() + def __init__( self, batch_intermediates: list[IntermediateValues] | None = None, @@ -154,13 +191,16 @@ def size(self) -> dict[torch.device, int]: :return: dictionary mapping torch device to number of bytes in cache """ 
sizes = defaultdict(lambda: 0) + memo = set() def _size_helper(intermediate: IntermediateValue) -> int: value = intermediate.value match value: case torch.Tensor(): - sizes[value.device] += value.nbytes + if value not in memo: + sizes[value.device] += value.nbytes + memo.add(value) case list() | tuple(): for v in value: _size_helper(v) @@ -239,8 +279,17 @@ def _offload_value( kwargs = {"offload_device": offload_device, "onload_device": onload_device} match value: case torch.Tensor(): + with OverrideEqMode(): + # check for cache hit between shared tensors + if value in cls.offload_values: + offloaded = cls.offload_values[value] + else: + # move to offload if no hit + offloaded = value.to(device=offload_device) + cls.offload_values[value] = offloaded + return IntermediateValue( - value=value.to(device=offload_device), + value=offloaded, device=(onload_device if onload_device else value.device), ) case list(): diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index eff88d881d..9c01ed8fba 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader, StackDataset -from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.pipelines.cache import IntermediatesCache, OverrideEqMode @dataclass @@ -162,3 +162,18 @@ def deep_equal(a, b) -> bool: return deep_equal(a_dict, b_dict) case _: return a == b + + +def test_override_eq_mode(): + a = torch.tensor([1, 2, 3]) + b = a + c = torch.tensor([2, 2, 2]) + + with pytest.raises(RuntimeError): + assert a == b + with pytest.raises(RuntimeError): + assert not (a == c) + + with OverrideEqMode(): + assert a == b + assert not (a == c)