Skip to content

Commit b26703f

Browse files
kylesayrs and dsikka authored
[Utils] Offloaded cache size (#1714)
## Purpose ## * Make debugging offloaded cache memory easier --------- Signed-off-by: Kyle Sayers <[email protected]> Co-authored-by: Dipika Sikka <[email protected]>
1 parent e5591f4 commit b26703f

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

src/llmcompressor/pipelines/cache.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import sys
12
import warnings
3+
from collections import defaultdict
24
from dataclasses import dataclass, fields, is_dataclass
35
from typing import Any, Dict, Generator, List, Optional, Union
46

@@ -132,10 +134,51 @@ def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
132134
del intermediates[name]
133135

134136
def append(self, values: Dict[str, Any]):
    """
    Add a new batch of values to the cache. The values are stored under the
    next unused batch index.

    :param values: dictionary mapping names to the values to cache
    """
    # Grow the cache by one empty batch, then fill it via update()
    self.batch_intermediates.append({})
    next_index = len(self.batch_intermediates) - 1
    self.update(next_index, values)
138146

147+
def size(self) -> Dict[torch.device, int]:
148+
"""
149+
Returns the memory used by cached values, keyed by device, in bytes
150+
151+
:return: dictionary mapping torch device to number of bytes in cache
152+
"""
153+
sizes = defaultdict(lambda: 0)
154+
155+
def _size_helper(intermediate: IntermediateValue) -> int:
156+
value = intermediate.value
157+
158+
if isinstance(value, torch.Tensor):
159+
sizes[value.device] += value.nbytes
160+
161+
elif is_dataclass(value):
162+
for field in fields(value):
163+
_size_helper(getattr(value, field.name))
164+
165+
elif isinstance(value, (tuple, list)):
166+
for v in value:
167+
_size_helper(v)
168+
169+
elif isinstance(value, dict):
170+
for v in value.values():
171+
_size_helper(v)
172+
173+
else:
174+
sizes[torch.device("cpu")] += sys.getsizeof(value, 0)
175+
176+
for intermediates in self.batch_intermediates:
177+
for value in intermediates.values():
178+
_size_helper(value)
179+
180+
return dict(sizes)
181+
139182
def iter(
140183
self, input_names: Optional[List[str]] = None
141184
) -> Generator[Any, None, None]:
@@ -162,6 +205,9 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any:
162205

163206
return value
164207

208+
if isinstance(value, list):
209+
return list(self._onload_value(v) for v in value)
210+
165211
if isinstance(value, tuple):
166212
return tuple(self._onload_value(v) for v in value)
167213

@@ -188,6 +234,11 @@ def _offload_value(self, value: Any) -> IntermediateValue:
188234

189235
return IntermediateValue(value=value, device=None)
190236

237+
if isinstance(value, list):
238+
return IntermediateValue(
239+
value=list(self._offload_value(v) for v in value), device=None
240+
)
241+
191242
if isinstance(value, tuple):
192243
return IntermediateValue(
193244
value=tuple(self._offload_value(v) for v in value), device=None

0 commit comments

Comments
 (0)