@@ -162,7 +162,7 @@ def _size_helper(intermediate: IntermediateValue) -> int:
162
162
for field in fields (value ):
163
163
_size_helper (getattr (value , field .name ))
164
164
165
- elif isinstance (value , tuple ):
165
+ elif isinstance (value , ( tuple , list ) ):
166
166
for v in value :
167
167
_size_helper (v )
168
168
@@ -205,6 +205,9 @@ def _onload_value(self, intermediate: IntermediateValue) -> Any:
205
205
206
206
return value
207
207
208
+ if isinstance (value , list ):
209
+ return list (self ._onload_value (v ) for v in value )
210
+
208
211
if isinstance (value , tuple ):
209
212
return tuple (self ._onload_value (v ) for v in value )
210
213
@@ -231,6 +234,11 @@ def _offload_value(self, value: Any) -> IntermediateValue:
231
234
232
235
return IntermediateValue (value = value , device = None )
233
236
237
+ if isinstance (value , list ):
238
+ return IntermediateValue (
239
+ value = list (self ._offload_value (v ) for v in value ), device = None
240
+ )
241
+
234
242
if isinstance (value , tuple ):
235
243
return IntermediateValue (
236
244
value = tuple (self ._offload_value (v ) for v in value ), device = None
0 commit comments