-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathcache.py
More file actions
290 lines (242 loc) · 10.8 KB
/
cache.py
File metadata and controls
290 lines (242 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from __future__ import annotations

import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, ClassVar, Generator
from weakref import WeakKeyDictionary

import torch
from tqdm import tqdm
@dataclass
class IntermediateValue:
    """
    Dataclass which recursively defines offloaded values and which device to onload to

    :param value: either an offloaded Tensor, a primitive value, or a recursable value
        (list, tuple, dict, or dataclass whose members are themselves
        IntermediateValue instances)
    :param device: if the value is a Tensor, then the device to onload the tensor to,
        otherwise None
    """

    # offloaded tensor, primitive, or nested container of IntermediateValue
    value: torch.Tensor | "IntermediateValue" | Any
    # onload target device when `value` is a Tensor; None for non-tensor values
    device: torch.device | None
# Mapping from input/activation name to its (possibly offloaded) value;
# one such mapping is stored per batch
IntermediateValues = dict[str, IntermediateValue]
class IntermediatesCache:
"""
Cache which stores intermediate values (activations) produced by batched, sequential
execution of models. Values are offloaded to the `offload_device` when stored in
the cache and onloaded to their original device when fetched from the cache. If
`offload_device` is None, values will not be offloaded at all.
Currently supports nested offloading of dataclass instances and tuples
Construct using `empty` and `from_dataloader` class methods
"""
batch_intermediates: list[IntermediateValues]
offload_device: torch.device | None
# onload value -> offload value
# used to avoid excess memory usage when shared tensors are offloaded
offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()
def __init__(
self,
batch_intermediates: list[IntermediateValues] | None = None,
offload_device: torch.device | None = "cpu",
):
self.batch_intermediates = batch_intermediates or []
self.offload_device = offload_device
@classmethod
def empty(cls, num_batches: int, offload_device: torch.device):
"""
Construct an empty cache
:param num_batches: the expected number of batches to be stored
:param offload_device: device to offload values to
"""
batch_intermediates = [{} for _ in range(num_batches)]
return cls(batch_intermediates, offload_device)
@classmethod
def from_dataloader(
cls,
dataloader: torch.utils.data.DataLoader,
model_device: torch.device = torch.device("cpu"),
offload_device: torch.device | None = torch.device("cpu"),
):
"""
Initialize a cache with data from the provided dataloader
This method iterates through all batches in the dataloader and offloads
them to the specified device. For faster cache preparation, consider:
- Increasing batch_size to reduce the number of iterations
- Using num_workers > 0 in the DataLoader for parallel loading
- Ensuring data preprocessing is done before creating the dataloader
:param dataloader: dataloader which generates values to be cached
:param model_device: device which values will be onloaded to when fetched
:param offload_device: device to offload values to
"""
batch_intermediates = [
{
key: cls._offload_value(value, offload_device, model_device)
for key, value in batch.items()
}
for batch in tqdm(dataloader, desc="Preparing cache")
]
return cls(batch_intermediates, offload_device)
def fetch(
self, batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]:
"""
Fetch values belonging to a batch
:param batch_index: index of batch whose values are being fetched
:param input_names: list of keys whose values are being fetched
:return: dictionary mapping keys to onloaded values
"""
intermediates = self.batch_intermediates[batch_index]
return {
key: self._onload_value(subgraph_input)
for key, subgraph_input in intermediates.items()
if input_names is None or key in input_names
}
def update(self, batch_index: int, values: dict[str, Any]):
"""
Update/put values belonging to a batch
:param batch_index: index of batch whose values will be updated
:param values: dictionary mapping keys to values used for update
"""
device = self.offload_device
intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
self.batch_intermediates[batch_index].update(intermediates)
def delete(self, batch_index: int, consumed_names: list[str] | None = None):
"""
Delete values from the cache
:param batch_index: index of batch whose values will be deleted
:param consumed_names: list of keys whose values will be deleted, defaults to
removing all keys
"""
intermediates = self.batch_intermediates[batch_index]
if consumed_names is None:
consumed_names = list(intermediates.keys())
for name in consumed_names:
del intermediates[name]
def append(self, values: dict[str, Any]):
"""
Append new values to the cache. The new values will be assigned the next
available batch index
:param values: dictionary mapping keys to values used for update
"""
batch_index = len(self.batch_intermediates)
self.batch_intermediates.append({})
self.update(batch_index, values)
def size(self) -> dict[torch.device, int]:
"""
Returns the memory used by cached values, keyed by device, in bytes
:return: dictionary mapping torch device to number of bytes in cache
"""
sizes = defaultdict(lambda: 0)
def _size_helper(intermediate: IntermediateValue) -> int:
value = intermediate.value
match value:
case torch.Tensor():
sizes[value.device] += value.nbytes
case list() | tuple():
for v in value:
_size_helper(v)
case dict():
for v in value.values():
_size_helper(v)
case _ if is_dataclass(value):
for field in fields(value):
_size_helper(getattr(value, field.name))
case _:
# this handles primitive values that don't match any other cases
sizes[torch.device("cpu")] += sys.getsizeof(value, 0)
for intermediates in self.batch_intermediates:
for value in intermediates.values():
_size_helper(value)
return dict(sizes)
def iter(self, input_names: list[str] | None = None) -> Generator[Any, None, None]:
for batch_index in range(len(self.batch_intermediates)):
yield self.fetch(batch_index, input_names)
def __iter__(self) -> Generator[Any, None, None]:
yield from self.iter()
def __len__(self) -> int:
return len(self.batch_intermediates)
@classmethod
def _onload_value(cls, intermediate: IntermediateValue) -> Any:
"""
Onload a value's tensors to the onload device
:param intermediate: intermediates value representation to onload
:return: original value with tensors onloaded to the onload device
"""
value = intermediate.value
device = intermediate.device
match value:
case torch.Tensor():
return value.to(device=device)
case list():
return [cls._onload_value(v) for v in value]
case tuple():
return tuple(cls._onload_value(v) for v in value)
case dict():
return {k: cls._onload_value(v) for k, v in value.items()}
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, cls._onload_value(v))
return value
case _:
# handles primitive values that should be returned as is.
# without this, a MatchError would be raised for unhandled types.
return value
@classmethod
def _offload_value(
cls,
value: Any,
offload_device: torch.device | None,
onload_device: torch.device | None = None,
) -> IntermediateValue:
"""
Offload a value's tensors to the offload device
:param value: value to offload
:param offload_device: device to offload `torch.Tensor` values to
:param onload_device: device used when onloading `torch.Tensor` values.
If None is provided, use the tensor's current device
:return: Instance of IntermediateValue representing the offloaded value
"""
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
match value:
case torch.Tensor():
# check for cache hit
if value in cls.offload_values:
offloaded = cls.offload_values[value]
else:
# move to offload if no hit
offloaded = value.to(device=offload_device)
cls.offload_values[value] = offloaded
return IntermediateValue(
value=offloaded,
device=(onload_device if onload_device else value.device),
)
case list():
return IntermediateValue(
value=[cls._offload_value(v, **kwargs) for v in value],
device=None,
)
case tuple():
return IntermediateValue(
value=tuple(cls._offload_value(v, **kwargs) for v in value),
device=None,
)
case dict():
return IntermediateValue(
value={
k: cls._offload_value(v, **kwargs) for k, v in value.items()
},
device=None,
)
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, cls._offload_value(v, **kwargs))
return IntermediateValue(value=value, device=None)
case _:
# handles primitive values and provides a warning for unsupported types.
# without this, values trigger a MatchError exception.
if not isinstance(
value,
(int, str, float, bool, torch.dtype, torch.device, type(None)),
):
warnings.warn(f"Offloading not implemented for type {type(value)}.")
return IntermediateValue(value=value, device=None)