1
+ import sys
1
2
import warnings
3
+ from collections import defaultdict
2
4
from dataclasses import dataclass , fields , is_dataclass
3
5
from typing import Any , Dict , Generator , List , Optional , Union
4
6
@@ -132,10 +134,51 @@ def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
132
134
del intermediates [name ]
133
135
134
136
def append(self, values: Dict[str, Any]):
    """
    Add a new batch to the end of the cache and populate it with `values`.
    The batch receives the next free batch index

    :param values: dictionary mapping keys to values used for update
    """
    # reserve the new batch slot first; its index is then the last position
    self.batch_intermediates.append({})
    new_index = len(self.batch_intermediates) - 1
    self.update(new_index, values)
138
146
147
+ def size (self ) -> Dict [torch .device , int ]:
148
+ """
149
+ Returns the memory used by cached values, keyed by device, in bytes
150
+
151
+ :return: dictionary mapping torch device to number of bytes in cache
152
+ """
153
+ sizes = defaultdict (lambda : 0 )
154
+
155
+ def _size_helper (intermediate : IntermediateValue ) -> int :
156
+ value = intermediate .value
157
+
158
+ if isinstance (value , torch .Tensor ):
159
+ sizes [value .device ] += value .nbytes
160
+
161
+ elif is_dataclass (value ):
162
+ for field in fields (value ):
163
+ _size_helper (getattr (value , field .name ))
164
+
165
+ elif isinstance (value , (tuple , list )):
166
+ for v in value :
167
+ _size_helper (v )
168
+
169
+ elif isinstance (value , dict ):
170
+ for v in value .values ():
171
+ _size_helper (v )
172
+
173
+ else :
174
+ sizes [torch .device ("cpu" )] += sys .getsizeof (value , 0 )
175
+
176
+ for intermediates in self .batch_intermediates :
177
+ for value in intermediates .values ():
178
+ _size_helper (value )
179
+
180
+ return dict (sizes )
181
+
139
182
def iter (
140
183
self , input_names : Optional [List [str ]] = None
141
184
) -> Generator [Any , None , None ]:
@@ -162,6 +205,9 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any:
162
205
163
206
return value
164
207
208
+ if isinstance (value , list ):
209
+ return list (self ._onload_value (v ) for v in value )
210
+
165
211
if isinstance (value , tuple ):
166
212
return tuple (self ._onload_value (v ) for v in value )
167
213
@@ -188,6 +234,11 @@ def _offload_value(self, value: Any) -> IntermediateValue:
188
234
189
235
return IntermediateValue (value = value , device = None )
190
236
237
+ if isinstance (value , list ):
238
+ return IntermediateValue (
239
+ value = list (self ._offload_value (v ) for v in value ), device = None
240
+ )
241
+
191
242
if isinstance (value , tuple ):
192
243
return IntermediateValue (
193
244
value = tuple (self ._offload_value (v ) for v in value ), device = None
0 commit comments