-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathcache.py
More file actions
325 lines (272 loc) · 12 KB
/
cache.py
File metadata and controls
325 lines (272 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
from __future__ import annotations

import copy
import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Generator
from weakref import WeakKeyDictionary

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from tqdm import tqdm
class OverrideEqMode(TorchDispatchMode):
    """
    When using a torch.Tensor as a key in a dictionary, the equality
    check must return a single value instead of a torch.Tensor
    of bool values.

    Use this override context for such cases. It swaps out the elementwise
    torch.eq equality check for an identity check: two tensors compare equal
    only if they are the very same object.

    >>> a = torch.tensor([1,2,3])
    >>> b = torch.tensor([1,2,3])
    >>> a == b
    tensor([True, True, True])
    >>> with OverrideEqMode():
    ...     a == b
    tensor(False)
    >>> with OverrideEqMode():
    ...     a == a
    tensor(True)
    """

    def __torch_dispatch__(self, func, _types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Only intercept tensor-vs-tensor equality
        if func is torch.ops.aten.eq.Tensor:
            # Override with an identity check (NOT torch.equal / value equality)
            assert len(args) == 2, "Exactly 2 args must be provided"
            # NOTE: Errors out without cast to torch.tensor
            return torch.tensor(args[0] is args[1])

        # For all other operations, just run them normally
        return func(*args, **kwargs)
@dataclass
class IntermediateValue:
    """
    Dataclass which recursively defines offloaded values and which device to onload to

    :param value: either an offloaded Tensor, a primitive value, or a container
        (list, tuple, dict, dataclass) whose elements are themselves
        IntermediateValue instances
    :param device: if the value is a Tensor, then the device to onload the tensor to,
        otherwise None
    """

    value: torch.Tensor | IntermediateValue | Any
    device: torch.device | None


# per-batch mapping of input/output names to their offloaded representations
IntermediateValues = dict[str, IntermediateValue]
class IntermediatesCache:
"""
Cache which stores intermediate values (activations) produced by batched, sequential
execution of models. Values are offloaded to the `offload_device` when stored in
the cache and onloaded to their original device when fetched from the cache. If
`offload_device` is None, values will not be offloaded at all.
Currently supports nested offloading of dataclass instances and tuples
Construct using `empty` and `from_dataloader` class methods
"""
batch_intermediates: list[IntermediateValues]
offload_device: torch.device | None
# map of onload value -> offload value
# used to avoid excess memory usage when shared tensors are offloaded
offload_values: WeakKeyDictionary[torch.Tensor, torch.Tensor] = WeakKeyDictionary()
def __init__(
self,
batch_intermediates: list[IntermediateValues] | None = None,
offload_device: torch.device | None = "cpu",
):
self.batch_intermediates = batch_intermediates or []
self.offload_device = offload_device
@classmethod
def empty(cls, num_batches: int, offload_device: torch.device):
"""
Construct an empty cache
:param num_batches: the expected number of batches to be stored
:param offload_device: device to offload values to
"""
batch_intermediates = [{} for _ in range(num_batches)]
return cls(batch_intermediates, offload_device)
@classmethod
def from_dataloader(
cls,
dataloader: torch.utils.data.DataLoader,
model_device: torch.device = torch.device("cpu"),
offload_device: torch.device | None = torch.device("cpu"),
):
"""
Initialize a cache with data from the provided dataloader
This method iterates through all batches in the dataloader and offloads
them to the specified device. For faster cache preparation, consider:
- Increasing batch_size to reduce the number of iterations
- Using num_workers > 0 in the DataLoader for parallel loading
- Ensuring data preprocessing is done before creating the dataloader
:param dataloader: dataloader which generates values to be cached
:param model_device: device which values will be onloaded to when fetched
:param offload_device: device to offload values to
"""
batch_intermediates = [
{
key: cls._offload_value(value, offload_device, model_device)
for key, value in batch.items()
}
for batch in tqdm(dataloader, desc="Preparing cache")
]
return cls(batch_intermediates, offload_device)
def fetch(
self, batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]:
"""
Fetch values belonging to a batch
:param batch_index: index of batch whose values are being fetched
:param input_names: list of keys whose values are being fetched
:return: dictionary mapping keys to onloaded values
"""
intermediates = self.batch_intermediates[batch_index]
return {
key: self._onload_value(subgraph_input)
for key, subgraph_input in intermediates.items()
if input_names is None or key in input_names
}
def update(self, batch_index: int, values: dict[str, Any]):
"""
Update/put values belonging to a batch
:param batch_index: index of batch whose values will be updated
:param values: dictionary mapping keys to values used for update
"""
device = self.offload_device
intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
self.batch_intermediates[batch_index].update(intermediates)
def delete(self, batch_index: int, consumed_names: list[str] | None = None):
"""
Delete values from the cache
:param batch_index: index of batch whose values will be deleted
:param consumed_names: list of keys whose values will be deleted, defaults to
removing all keys
"""
intermediates = self.batch_intermediates[batch_index]
if consumed_names is None:
consumed_names = list(intermediates.keys())
for name in consumed_names:
del intermediates[name]
def append(self, values: dict[str, Any]):
"""
Append new values to the cache. The new values will be assigned the next
available batch index
:param values: dictionary mapping keys to values used for update
"""
batch_index = len(self.batch_intermediates)
self.batch_intermediates.append({})
self.update(batch_index, values)
def size(self) -> dict[torch.device, int]:
"""
Returns the memory used by cached values, keyed by device, in bytes
:return: dictionary mapping torch device to number of bytes in cache
"""
sizes = defaultdict(lambda: 0)
memo = set()
def _size_helper(intermediate: IntermediateValue) -> int:
value = intermediate.value
match value:
case torch.Tensor():
if value not in memo:
sizes[value.device] += value.nbytes
memo.add(value)
case list() | tuple():
for v in value:
_size_helper(v)
case dict():
for v in value.values():
_size_helper(v)
case _ if is_dataclass(value):
for field in fields(value):
_size_helper(getattr(value, field.name))
case _:
# this handles primitive values that don't match any other cases
sizes[torch.device("cpu")] += sys.getsizeof(value, 0)
for intermediates in self.batch_intermediates:
for value in intermediates.values():
_size_helper(value)
return dict(sizes)
def iter(self, input_names: list[str] | None = None) -> Generator[Any, None, None]:
for batch_index in range(len(self.batch_intermediates)):
yield self.fetch(batch_index, input_names)
def __iter__(self) -> Generator[Any, None, None]:
yield from self.iter()
def __len__(self) -> int:
return len(self.batch_intermediates)
@classmethod
def _onload_value(cls, intermediate: IntermediateValue) -> Any:
"""
Onload a value's tensors to the onload device
:param intermediate: intermediates value representation to onload
:return: original value with tensors onloaded to the onload device
"""
value = intermediate.value
device = intermediate.device
match value:
case torch.Tensor():
return value.to(device=device)
case list():
return [cls._onload_value(v) for v in value]
case tuple():
return tuple(cls._onload_value(v) for v in value)
case dict():
return {k: cls._onload_value(v) for k, v in value.items()}
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, cls._onload_value(v))
return value
case _:
# handles primitive values that should be returned as is.
# without this, a MatchError would be raised for unhandled types.
return value
@classmethod
def _offload_value(
cls,
value: Any,
offload_device: torch.device | None,
onload_device: torch.device | None = None,
) -> IntermediateValue:
"""
Offload a value's tensors to the offload device
:param value: value to offload
:param offload_device: device to offload `torch.Tensor` values to
:param onload_device: device used when onloading `torch.Tensor` values.
If None is provided, use the tensor's current device
:return: Instance of IntermediateValue representing the offloaded value
"""
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
match value:
case torch.Tensor():
with OverrideEqMode():
# check for cache hit between shared tensors
if value in cls.offload_values:
offloaded = cls.offload_values[value]
else:
# move to offload if no hit
offloaded = value.to(device=offload_device)
cls.offload_values[value] = offloaded
return IntermediateValue(
value=offloaded,
device=(onload_device if onload_device else value.device),
)
case list():
return IntermediateValue(
value=[cls._offload_value(v, **kwargs) for v in value],
device=None,
)
case tuple():
return IntermediateValue(
value=tuple(cls._offload_value(v, **kwargs) for v in value),
device=None,
)
case dict():
return IntermediateValue(
value={
k: cls._offload_value(v, **kwargs) for k, v in value.items()
},
device=None,
)
case _ if is_dataclass(value):
for field in fields(value):
v = getattr(value, field.name)
setattr(value, field.name, cls._offload_value(v, **kwargs))
return IntermediateValue(value=value, device=None)
case _:
# handles primitive values and provides a warning for unsupported types.
# without this, values trigger a MatchError exception.
if not isinstance(
value,
(int, str, float, bool, torch.dtype, torch.device, type(None)),
):
warnings.warn(f"Offloading not implemented for type {type(value)}.")
return IntermediateValue(value=value, device=None)