Skip to content
16 changes: 15 additions & 1 deletion src/llmcompressor/pipelines/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Generator

from weakref import WeakKeyDictionary

import torch
from tqdm import tqdm

Expand Down Expand Up @@ -42,6 +44,10 @@ class IntermediatesCache:
batch_intermediates: list[IntermediateValues]
offload_device: torch.device | None

# Cache mapping each original ("onload") tensor to its already-offloaded copy.
# When the same tensor object is offloaded more than once (i.e. it is shared),
# the existing offloaded copy is reused instead of creating another one, which
# avoids excess memory usage on the offload device.
# Keys are held weakly, so an entry is dropped once its original tensor is no
# longer referenced elsewhere.
# Assigned in the class body and accessed through `cls`, so the mapping is
# shared rather than per-instance.
# NOTE(review): WeakKeyDictionary lookups compare weak references, which
# delegate to `==` on the live referents; `torch.Tensor.__eq__` is elementwise,
# so confirm that membership tests cannot raise for multi-element tensors.
offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()

def __init__(
self,
batch_intermediates: list[IntermediateValues] | None = None,
Expand Down Expand Up @@ -239,8 +245,16 @@ def _offload_value(
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
match value:
case torch.Tensor():
# check for cache hit
if value in cls.offload_values:
offloaded = cls.offload_values[value]
else:
# move to offload if no hit
offloaded = value.to(device=offload_device)
cls.offload_values[value] = offloaded

return IntermediateValue(
value=value.to(device=offload_device),
value=offloaded,
device=(onload_device if onload_device else value.device),
)
case list():
Expand Down
Loading