Commit eadcb73

add util

Signed-off-by: Kyle Sayers <[email protected]>
1 parent: b88221b

File tree: 1 file changed, +43 -0 lines

src/llmcompressor/pipelines/cache.py: 43 additions & 0 deletions
@@ -1,4 +1,6 @@
+import sys
 import warnings
+from collections import defaultdict
 from dataclasses import dataclass, fields, is_dataclass
 from typing import Any, Dict, Generator, List, Optional, Union
 
@@ -132,10 +134,51 @@ def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
             del intermediates[name]
 
     def append(self, values: Dict[str, Any]):
+        """
+        Append new values to the cache. The new values will be assigned the next
+        available batch index
+
+        :param values: dictionary mapping keys to values used for update
+        """
         batch_index = len(self.batch_intermediates)
         self.batch_intermediates.append({})
         self.update(batch_index, values)
 
+    def size(self) -> Dict[torch.device, int]:
+        """
+        Returns the memory used by cached values, keyed by device, in bytes
+
+        :return: dictionary mapping torch device to number of bytes in cache
+        """
+        sizes = defaultdict(lambda: 0)
+
+        def _size_helper(intermediate: IntermediateValue) -> int:
+            value = intermediate.value
+
+            if isinstance(value, torch.Tensor):
+                sizes[value.device] += value.nbytes
+
+            elif is_dataclass(value):
+                for field in fields(value):
+                    _size_helper(getattr(value, field.name))
+
+            elif isinstance(value, tuple):
+                for v in value:
+                    _size_helper(v)
+
+            elif isinstance(value, dict):
+                for v in value.values():
+                    _size_helper(v)
+
+            else:
+                sizes[torch.device("cpu")] += sys.getsizeof(value, 0)
+
+        for intermediates in self.batch_intermediates:
+            for value in intermediates.values():
+                _size_helper(value)
+
+        return dict(sizes)
+
     def iter(
         self, input_names: Optional[List[str]] = None
     ) -> Generator[Any, None, None]: