Skip to content

Commit 2e469aa

Browse files
kylesayrs and HDCharles authored
[Bugfix] Fix circular references when activation offload device is cuda (#2387)
## Background ## #2366 introduced a `WeakKeyDictionary` layer which caches shared tensors. This is a good approach, but has an edge case where, if the value of entry is identical to the key of the entry, then the key will never be garbage collected. This can occur if the user specifies `sequential_offload_device="cuda"`, or if the AWQ offload device is "cuda" (default true in most cases). ## Purpose ## * Fix memory leak in AWQ which led to very high CUDA memory usage ## Changes ## * Guard against entries into the `WeakKeyDictionary` where the key and value are identical * Misc * Move `OverrideEqMode` to the bottom of the `pipelines/cache.py` * Remove `_fp16_baseline_cache`, which was not being used ## Testing ## | Before Changes | After Changes | | - | - | | <img width="640" height="480" alt="awq_before" src="https://github.com/user-attachments/assets/07714321-4b2f-49b7-aa2b-5c745a60d2f4" /> | <img width="640" height="480" alt="awq_after" src="https://github.com/user-attachments/assets/336b0e98-c24c-4e0c-a873-3166effc32b7" /> | --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
1 parent 9979e98 commit 2e469aa

File tree

2 files changed

+33
-36
lines changed

2 files changed

+33
-36
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,6 @@ class AWQModifier(Modifier, QuantizationMixin):
172172
)
173173
# List to store error metrics for each layer
174174
_error_metrics: list[dict] = PrivateAttr(default_factory=list)
175-
# Cache FP16 baseline outputs for each parent module, one list of tensors per batch
176-
_fp16_baseline_cache: dict[Module, IntermediatesCache] = PrivateAttr(
177-
default_factory=dict
178-
)
179175

180176
def on_initialize(self, state: State, **kwargs) -> bool:
181177
"""

src/llmcompressor/pipelines/cache.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,6 @@
1212
from tqdm import tqdm
1313

1414

15-
class OverrideEqMode(TorchDispatchMode):
16-
"""
17-
When using a torch.Tensor as a key in a dictionary, the equality
18-
check must return a single value instead of a torch.Tensor
19-
of bool values.
20-
Use this override context for such cases, to swap out the torch.eq
21-
equality check for a check on id
22-
>>> a = torch.tensor([1,2,3])
23-
>>> b = torch.tensor([1,2,3])
24-
>>> a == b
25-
tensor([True, True, True])
26-
>>> with OverrideEqMode():
27-
... a == b
28-
tensor(True)
29-
"""
30-
31-
def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
32-
kwargs = kwargs or {}
33-
34-
# Check if the operation is equality
35-
if func is torch.ops.aten.eq.Tensor:
36-
# Override to use torch.equal
37-
assert len(args) == 2, "Exactly 2 args must be provided"
38-
39-
# NOTE: Errors out without cast to torch.tensor
40-
return torch.tensor(id(args[0]) == id(args[1]))
41-
42-
# For all other operations, just run them normally
43-
return func(*args, **kwargs)
44-
45-
4615
@dataclass
4716
class IntermediateValue:
4817
"""
@@ -289,7 +258,8 @@ def _offload_value(
289258
else:
290259
# move to offload if no hit
291260
offloaded = value.to(device=offload_device)
292-
cls.offload_values[value] = offloaded
261+
if offloaded is not value: # avoid circular ref
262+
cls.offload_values[value] = offloaded
293263

294264
return IntermediateValue(
295265
value=offloaded,
@@ -326,3 +296,34 @@ def _offload_value(
326296
):
327297
warnings.warn(f"Offloading not implemented for type {type(value)}.")
328298
return IntermediateValue(value=value, device=None)
299+
300+
301+
class OverrideEqMode(TorchDispatchMode):
302+
"""
303+
When using a torch.Tensor as a key in a dictionary, the equality
304+
check must return a single value instead of a torch.Tensor
305+
of bool values.
306+
Use this override context for such cases, to swap out the torch.eq
307+
equality check for a check on id
308+
>>> a = torch.tensor([1,2,3])
309+
>>> b = torch.tensor([1,2,3])
310+
>>> a == b
311+
tensor([True, True, True])
312+
>>> with OverrideEqMode():
313+
... a == b
314+
tensor(True)
315+
"""
316+
317+
def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
318+
kwargs = kwargs or {}
319+
320+
# Check if the operation is equality
321+
if func is torch.ops.aten.eq.Tensor:
322+
# Override to use torch.equal
323+
assert len(args) == 2, "Exactly 2 args must be provided"
324+
325+
# NOTE: Errors out without cast to torch.tensor
326+
return torch.tensor(id(args[0]) == id(args[1]))
327+
328+
# For all other operations, just run them normally
329+
return func(*args, **kwargs)

0 commit comments

Comments (0)