|
7 | 7 | from typing import Any, Generator |
8 | 8 |
|
9 | 9 | #from .helpers import TensorKeyWeakValueDictionary |
10 | | -from weakref import WeakKeyDictionary |
| 10 | +from weakref import WeakKeyDictionary, ReferenceType, ref |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | from tqdm import tqdm |
@@ -48,7 +48,7 @@ class IntermediatesCache: |
48 | 48 | # onload value -> offload value |
49 | 49 | # used to avoid excess memory usage when shared tensors are offloaded |
50 | 50 | #offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary() |
51 | | - offload_values: dict[int, torch.Tensor] = dict() |
| 51 | + offload_values: dict[int, ReferenceType[torch.Tensor]] = dict() |
52 | 52 |
|
53 | 53 | def __init__( |
54 | 54 | self, |
@@ -254,14 +254,13 @@ def _offload_value( |
254 | 254 | # Note: due to a bug in WeakKeyDictionary, we must check tensors using |
255 | 255 | # id. This is UNSAFE: once the onloaded tensor is deleted, other |
256 | 256 | # Python objects can reuse that id, leading to collisions. |
257 | | - key = id(value) + sum(value.shape) + len(value.shape) #torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value) |
258 | | - if key in cls.offload_values: |
259 | | - offloaded = cls.offload_values[key] |
| 257 | + value_hash = torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value) |
| 258 | + if value_hash in cls.offload_values: |
| 259 | + offloaded = cls.offload_values[value_hash]() |
260 | 260 | else: |
261 | 261 | # move to offload if no hit |
262 | 262 | offloaded = value.to(device=offload_device) |
263 | | - cls.offload_values[key] = offloaded |
264 | | - offloaded = value.to(device=offload_device) |
| 263 | + cls.offload_values[value_hash] = ref(offloaded) |
265 | 264 |
|
266 | 265 | return IntermediateValue( |
267 | 266 | value=offloaded, |
|
0 commit comments