Skip to content

Commit 0d556a7

Browse files
[Sequential Pipeline] only cache unique offloaded values (#2366)
Updated by @brian-dellabetta SUMMARY: The SequentialPipeline offloads subgraph outputs as part of normal usage. Occasionally these outputs share duplicates in kwargs that point to the same memory location on the onloaded device. When offloading is enabled, there was previously no check to see if any tensors to be offloaded had already previously been offloaded, which can cause a huge increase in memory requirements in some models, as reported in #2363. This PR - [x] adds an offload map to IntermediatesCache to ensure tensors are not redundantly offloaded - [x] wraps the map in an override to ensure `torch.equal` is used rather than `torch.eq` (which is the one used with `==` checks). `torch.eq` can return multiple boolean values depending on the tensors being compared, resulting in an error. This override, which should only be used when the tensors are immutable (the case here), allows us to retain the original hashing function and have an `O(1)` lookup. Our other attempts to circumvent the issue added to runtime or required `O(N)` lookup. Resolves #2363 TEST PLAN: - [x] Unit test added for `OverrideEqMode` - [x] Script from #2363 runs with ~81GB CPU RAM after first layer propagation, increased to ~88GB CPU RAM used by layer 11/49, and then stays consistently <89GB CPU RAM used by layer 25/49. On current main, this script would hit ~750GB CPU RAM usage during first layer propagation --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Brian Dellabetta <bdellabe@redhat.com> Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com> Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 556b503 commit 0d556a7

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

src/llmcompressor/pipelines/cache.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,44 @@
55
from collections import defaultdict
66
from dataclasses import dataclass, fields, is_dataclass
77
from typing import Any, Generator
8+
from weakref import WeakKeyDictionary
89

910
import torch
11+
from torch.utils._python_dispatch import TorchDispatchMode
1012
from tqdm import tqdm
1113

1214

15+
class OverrideEqMode(TorchDispatchMode):
    """
    When using a torch.Tensor as a key in a dictionary, the equality
    check must return a single value instead of a torch.Tensor
    of bool values.

    Use this override context for such cases, to swap out the elementwise
    torch.eq equality check for an object-identity check. Note that two
    distinct tensors compare False under this mode even if their values
    are equal; only the same tensor object compares True.

    >>> a = torch.tensor([1, 2, 3])
    >>> b = a
    >>> c = torch.tensor([1, 2, 3])
    >>> a == b
    tensor([True, True, True])
    >>> with OverrideEqMode():
    ...     print(a == b)  # same object -> True
    ...     print(a == c)  # equal values but different objects -> False
    tensor(True)
    tensor(False)
    """

    def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
        """
        Intercept aten ops: rewrite tensor equality to an identity check,
        pass every other op through unchanged.

        :param func: the aten operator being dispatched
        :param _types: tensor subclass types involved (unused)
        :param args: positional arguments to the operator
        :param kwargs: keyword arguments to the operator, may be None
        :return: result of the (possibly overridden) operator
        """
        kwargs = kwargs or {}

        # Check if the operation is elementwise tensor equality (`==`)
        if func is torch.ops.aten.eq.Tensor:
            # Override elementwise equality with an object-identity check,
            # which always yields a single boolean
            assert len(args) == 2, "Exactly 2 args must be provided"

            # NOTE: Errors out without cast to torch.tensor
            return torch.tensor(args[0] is args[1])

        # For all other operations, just run them normally
        return func(*args, **kwargs)
44+
45+
1346
@dataclass
1447
class IntermediateValue:
1548
"""
@@ -42,6 +75,10 @@ class IntermediatesCache:
4275
batch_intermediates: list[IntermediateValues]
4376
offload_device: torch.device | None
4477

78+
# map of onload value -> offload value
79+
# used to avoid excess memory usage when shared tensors are offloaded
80+
offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()
81+
4582
def __init__(
4683
self,
4784
batch_intermediates: list[IntermediateValues] | None = None,
@@ -154,13 +191,16 @@ def size(self) -> dict[torch.device, int]:
154191
:return: dictionary mapping torch device to number of bytes in cache
155192
"""
156193
sizes = defaultdict(lambda: 0)
194+
memo = set()
157195

158196
def _size_helper(intermediate: IntermediateValue) -> int:
159197
value = intermediate.value
160198

161199
match value:
162200
case torch.Tensor():
163-
sizes[value.device] += value.nbytes
201+
if value not in memo:
202+
sizes[value.device] += value.nbytes
203+
memo.add(value)
164204
case list() | tuple():
165205
for v in value:
166206
_size_helper(v)
@@ -239,8 +279,17 @@ def _offload_value(
239279
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
240280
match value:
241281
case torch.Tensor():
282+
with OverrideEqMode():
283+
# check for cache hit between shared tensors
284+
if value in cls.offload_values:
285+
offloaded = cls.offload_values[value]
286+
else:
287+
# move to offload if no hit
288+
offloaded = value.to(device=offload_device)
289+
cls.offload_values[value] = offloaded
290+
242291
return IntermediateValue(
243-
value=value.to(device=offload_device),
292+
value=offloaded,
244293
device=(onload_device if onload_device else value.device),
245294
)
246295
case list():

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch.utils.data import DataLoader, StackDataset
66

7-
from llmcompressor.pipelines.cache import IntermediatesCache
7+
from llmcompressor.pipelines.cache import IntermediatesCache, OverrideEqMode
88

99

1010
@dataclass
@@ -162,3 +162,18 @@ def deep_equal(a, b) -> bool:
162162
return deep_equal(a_dict, b_dict)
163163
case _:
164164
return a == b
165+
166+
167+
def test_override_eq_mode():
    """Tensor `==` is ambiguous under `assert` normally, but collapses to a
    single identity-based boolean inside OverrideEqMode."""
    first = torch.tensor([1, 2, 3])
    alias = first
    other = torch.tensor([2, 2, 2])

    # Outside the mode, `==` returns a multi-element bool tensor, which
    # cannot be coerced to a single truth value by `assert`/`not`
    with pytest.raises(RuntimeError):
        assert first == alias
    with pytest.raises(RuntimeError):
        assert not (first == other)

    # Inside the mode, `==` yields one boolean based on object identity
    with OverrideEqMode():
        assert first == alias
        assert not (first == other)

0 commit comments

Comments
 (0)