
Commit 297bc93

[V1][Perf] Avoid mem duplication when aggregating MM tensors
Signed-off-by: Nick Hill <[email protected]>
1 parent 268c325

File tree

1 file changed: +25 −6 lines changed

vllm/multimodal/inputs.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].unsqueeze(0).contiguous()
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0].unsqueeze(0)
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
-                return torch.stack(batch)
+                stack = torch.stack(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                for i in range(len(batch)):
+                    batch[i] = stack[i]
+                return stack
 
         return batch
```
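The stack path in isolation — a minimal sketch of the pattern, assuming the input list is the only remaining holder of references to the original tensors; `stack_and_release` is a hypothetical name for illustration, not vLLM API:

```python
import torch

def stack_and_release(batch: list[torch.Tensor]) -> torch.Tensor:
    """Stack `batch` along a new dim 0, then rebind each list slot to a
    view of the result so the original tensors' storage can be freed."""
    # torch.stack always copies its inputs into freshly allocated memory,
    # so at this point both the originals and the copy are alive.
    stacked = torch.stack(batch)
    # stacked[i] is a view sharing storage with `stacked`; rebinding the
    # list slots drops the last references to the original tensors,
    # keeping peak usage near 1x the batch rather than 2x.
    for i in range(len(batch)):
        batch[i] = stacked[i]
    return stacked

batch = [torch.randn(3, 4) for _ in range(8)]
out = stack_and_release(batch)
assert out.shape == (8, 3, 4)
assert batch[0].data_ptr() == out.data_ptr()  # batch[0] now views `out`
```

The single-tensor branch relies on a related property: `Tensor.contiguous()` returns the tensor itself when it is already contiguous (the zero-copy case the comment mentions) and only allocates a copy otherwise, so rebinding `batch[0]` matters precisely in the copying case, where the non-contiguous original would otherwise stay referenced by the list.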

```diff
@@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].contiguous()
-            first_shape = batch[0].shape
-            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
-                return torch.concat(batch)
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0]
+            first_shape = batch[0].shape[1:]
+            if all(elem.shape[1:] == first_shape for elem in batch):
+                concat = torch.concat(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                off = 0
+                for i in range(len(batch)):
+                    size = batch[i].shape[0]
+                    batch[i] = concat[off:off + size]
+                    off += size
+                return concat
 
         return [e for elem in batch for e in elem]
```
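The concat branch applies the same rebinding idea, but each element may contribute a different number of rows, so the replacement views are carved out with a running offset. Again a sketch under the same assumptions; `concat_and_release` is a hypothetical name:

```python
import torch

def concat_and_release(batch: list[torch.Tensor]) -> torch.Tensor:
    """Concatenate `batch` along dim 0, then rebind each list slot to the
    slice of the result that now holds its data."""
    concatenated = torch.concat(batch)  # torch.concat is an alias of torch.cat
    off = 0
    for i in range(len(batch)):
        size = batch[i].shape[0]  # rows this element contributed, read before rebinding
        batch[i] = concatenated[off:off + size]  # a view, not a copy
        off += size
    return concatenated

batch = [torch.randn(2, 4), torch.randn(5, 4), torch.randn(1, 4)]
out = concat_and_release(batch)
assert out.shape == (8, 4)
assert batch[1].shape == (5, 4)  # same data, now backed by `out`
```

Note the guard only requires matching trailing dimensions (`elem.shape[1:]`), since dim 0 sizes are allowed to differ under concatenation; the diff also hoists the `[1:]` out of the comparison by computing `first_shape = batch[0].shape[1:]` once.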
