Skip to content

Commit f2e6efc

Browse files
Bdellabe/shared tensor caching (#2367)
SUMMARY: "please provide a brief summary"
TEST PLAN: "please outline how the changes were tested"
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent d8f813d commit f2e6efc

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/llmcompressor/pipelines/cache.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Generator
88

99
#from .helpers import TensorKeyWeakValueDictionary
10-
from weakref import WeakKeyDictionary
10+
from weakref import WeakKeyDictionary, ReferenceType, ref
1111

1212
import torch
1313
from tqdm import tqdm
@@ -48,7 +48,7 @@ class IntermediatesCache:
4848
# onload value -> offload value
4949
# used to avoid excess memory usage when shared tensors are offloaded
5050
#offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()
51-
offload_values: dict[int, torch.Tensor] = dict()
51+
offload_values: dict[int, ReferenceType[torch.Tensor]] = dict()
5252

5353
def __init__(
5454
self,
@@ -254,14 +254,13 @@ def _offload_value(
254254
# Note: due to a (bug) in WeakKeyDictionary, we must check tensors using
255255
# id. this is UNSAFE, since once the onloaded tensor is deleted, other
256256
# python objects can reuse that id, leading to collisions.
257-
key = id(value) + sum(value.shape) + len(value.shape) #torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value)
258-
if key in cls.offload_values:
259-
offloaded = cls.offload_values[key]
257+
value_hash = torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value)
258+
if value_hash in cls.offload_values:
259+
offloaded = cls.offload_values[value_hash]()
260260
else:
261261
# move to offload if no hit
262262
offloaded = value.to(device=offload_device)
263-
cls.offload_values[key] = offloaded
264-
offloaded = value.to(device=offload_device)
263+
cls.offload_values[value_hash] = ref(offloaded)
265264

266265
return IntermediateValue(
267266
value=offloaded,

0 commit comments

Comments
 (0)