@@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
305
305
# An optimization when `batch` contains only one tensor:
306
306
# - produce exactly same result as `torch.stack(batch)`
307
307
# - will achieve zero-copy if the tensor is contiguous
308
- return batch [0 ].unsqueeze (0 ).contiguous ()
308
+ # Replace original tensor so that its memory can be freed
309
+ # in the non-contiguous case.
310
+ batch [0 ] = batch [0 ].contiguous ()
311
+ return batch [0 ].unsqueeze (0 )
309
312
first_shape = batch [0 ].shape
310
313
if all (elem .shape == first_shape for elem in batch ):
311
- return torch .stack (batch )
314
+ stack = torch .stack (batch )
315
+ # Replace original tensors with slices into the new one,
316
+ # so that their memory can be freed.
317
+ for i in range (len (batch )):
318
+ batch [i ] = stack [i ]
319
+ return stack
312
320
313
321
return batch
314
322
@@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
337
345
# An optimization when `batch` contains only one tensor:
338
346
# - produce exactly same result as `torch.concat(batch)`
339
347
# - will achieve zero-copy if the tensor is contiguous
340
- return batch [0 ].contiguous ()
341
- first_shape = batch [0 ].shape
342
- if all (elem .shape [1 :] == first_shape [1 :] for elem in batch ):
343
- return torch .concat (batch )
348
+ # Replace original tensor so that its memory can be freed
349
+ # in the non-contiguous case.
350
+ batch [0 ] = batch [0 ].contiguous ()
351
+ return batch [0 ]
352
+ first_shape = batch [0 ].shape [1 :]
353
+ if all (elem .shape [1 :] == first_shape for elem in batch ):
354
+ concat = torch .concat (batch )
355
+ # Replace original tensors with slices into the new one,
356
+ # so that their memory can be freed.
357
+ off = 0
358
+ for i in range (len (batch )):
359
+ size = batch [i ].shape [0 ]
360
+ batch [i ] = concat [off :off + size ]
361
+ off += size
362
+ return concat
344
363
345
364
return [e for elem in batch for e in elem ]
346
365
0 commit comments