diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index 364df00d4..90ba89ffd 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -1,4 +1,6 @@
+import sys
 import warnings
+from collections import defaultdict
 from dataclasses import dataclass, fields, is_dataclass
 from typing import Any, Dict, Generator, List, Optional, Union
 
@@ -132,10 +134,51 @@ def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
             del intermediates[name]
 
     def append(self, values: Dict[str, Any]):
+        """
+        Append new values to the cache. The new values will be assigned the next
+        available batch index
+
+        :param values: dictionary mapping keys to values used for update
+        """
         batch_index = len(self.batch_intermediates)
         self.batch_intermediates.append({})
         self.update(batch_index, values)
 
+    def size(self) -> Dict[torch.device, int]:
+        """
+        Returns the memory used by cached values, keyed by device, in bytes
+
+        :return: dictionary mapping torch device to number of bytes in cache
+        """
+        sizes = defaultdict(lambda: 0)
+
+        def _size_helper(intermediate: IntermediateValue) -> int:
+            value = intermediate.value
+
+            if isinstance(value, torch.Tensor):
+                sizes[value.device] += value.nbytes
+
+            elif is_dataclass(value):
+                for field in fields(value):
+                    _size_helper(getattr(value, field.name))
+
+            elif isinstance(value, tuple):
+                for v in value:
+                    _size_helper(v)
+
+            elif isinstance(value, dict):
+                for v in value.values():
+                    _size_helper(v)
+
+            else:
+                sizes[torch.device("cpu")] += sys.getsizeof(value, 0)
+
+        for intermediates in self.batch_intermediates:
+            for value in intermediates.values():
+                _size_helper(value)
+
+        return dict(sizes)
+
     def iter(
         self, input_names: Optional[List[str]] = None
     ) -> Generator[Any, None, None]:
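
The new size() method walks every cached IntermediateValue, accumulating tensor bytes per torch.device and falling back to sys.getsizeof on CPU for non-tensor leaves. A minimal usage sketch follows; it assumes `cache` is an already-populated instance of the intermediates cache class modified above (the class name and its construction are not shown in this diff), and the helper name `log_cache_footprint` is hypothetical.

    def log_cache_footprint(cache) -> None:
        # size() returns a dict mapping torch.device -> number of cached bytes
        for device, num_bytes in cache.size().items():
            print(f"{device}: {num_bytes / (1 << 20):.2f} MiB cached")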