diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 90ecb1d318..b63df895a6 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -7,7 +7,7 @@ from typing import Any, Generator #from .helpers import TensorKeyWeakValueDictionary -from weakref import WeakKeyDictionary +from weakref import WeakKeyDictionary, ReferenceType, ref import torch from tqdm import tqdm @@ -48,7 +48,7 @@ class IntermediatesCache: # onload value -> offload value # used to avoid excess memory usage when shared tensors are offloaded #offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary() - offload_values: dict[int, torch.Tensor] = dict() + offload_values: dict[int, ReferenceType[torch.Tensor]] = dict() def __init__( self, @@ -254,14 +254,13 @@ def _offload_value( # Note: due to a (bug) in WeakKeyDictionary, we must check tensors using # id. this is UNSAFE, since once the onloaded tensor is deleted, other # python objects can reuse that id, leading to collisions. - key = id(value) + sum(value.shape) + len(value.shape) #torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value) - if key in cls.offload_values: - offloaded = cls.offload_values[key] + value_hash = torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value) + if value_hash in cls.offload_values: + offloaded = cls.offload_values[value_hash]() else: # move to offload if no hit offloaded = value.to(device=offload_device) - cls.offload_values[key] = offloaded - offloaded = value.to(device=offload_device) + cls.offload_values[value_hash] = ref(offloaded) return IntermediateValue( value=offloaded,