From 31b7cbe8c9c79a8a090ef1bfb375e06b49a7ff35 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Sat, 14 Feb 2026 01:01:53 +0000 Subject: [PATCH] using torch hash_tensor and weakref Signed-off-by: Brian Dellabetta --- src/llmcompressor/pipelines/cache.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index a90d8c32eb..b63df895a6 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -7,7 +7,7 @@ from typing import Any, Generator #from .helpers import TensorKeyWeakValueDictionary -from weakref import WeakKeyDictionary +from weakref import WeakKeyDictionary, ReferenceType, ref import torch from tqdm import tqdm @@ -48,7 +48,7 @@ class IntermediatesCache: # onload value -> offload value # used to avoid excess memory usage when shared tensors are offloaded #offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary() - offload_values: dict[int, torch.Tensor] = dict() + offload_values: dict[int, ReferenceType[torch.Tensor]] = dict() def __init__( self, @@ -254,12 +254,13 @@ def _offload_value( # Note: due to a (bug) in WeakKeyDictionary, we must check tensors using # id. this is UNSAFE, since once the onloaded tensor is deleted, other # python objects can reuse that id, leading to collisions. - if id(value) in cls.offload_values: - offloaded = cls.offload_values[id(value)] + value_hash = torch.hash_tensor(torch.view_as_real(value) if value.is_complex() else value) + if value_hash in cls.offload_values: + offloaded = cls.offload_values[value_hash]() else: # move to offload if no hit offloaded = value.to(device=offload_device) - cls.offload_values[id(value)] = offloaded + cls.offload_values[value_hash] = ref(offloaded) return IntermediateValue( value=offloaded,