vllm/multimodal/inputs.py: 25 additions & 6 deletions
```diff
@@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
             # An optimization when `batch` contains only one tensor:
             # - produce exactly same result as `torch.stack(batch)`
             # - will achieve zero-copy if the tensor is contiguous
-            return batch[0].unsqueeze(0).contiguous()
+            # Replace original tensor so that its memory can be freed
+            # in the non-contiguous case.
+            batch[0] = batch[0].contiguous()
+            return batch[0].unsqueeze(0)
         first_shape = batch[0].shape
         if all(elem.shape == first_shape for elem in batch):
-            return torch.stack(batch)
+            stack = torch.stack(batch)
+            # Replace original tensors with slices into the new one,
+            # so that their memory can be freed.
+            for i in range(len(batch)):
+                batch[i] = stack[i]
+            return stack

         return batch
```
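To illustrate the intent of the stacking branch, here is a standalone sketch (illustrative shapes, not the vLLM code path): `torch.stack` copies the inputs into one new tensor, so overwriting each list entry with a view into the stacked result drops the last reference `batch` holds to the original per-item tensors, letting their memory be reclaimed. The single-tensor branch works the same way, reassigning `batch[0] = batch[0].contiguous()` so a non-contiguous original can be freed once the contiguous copy replaces it.

```python
import torch

# Standalone sketch of the stack branch (assumed shapes, not vLLM code).
batch = [torch.randn(3, 4) for _ in range(8)]

stack = torch.stack(batch)   # copies the data into one contiguous tensor
for i in range(len(batch)):
    batch[i] = stack[i]      # each entry becomes a view sharing stack's storage

# The original tensors are no longer referenced by `batch` (assuming no
# other references exist), so their buffers can be reclaimed.
assert batch[0].data_ptr() == stack.data_ptr()
assert torch.equal(torch.stack(batch), stack)
```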

```diff
@@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
             # An optimization when `batch` contains only one tensor:
             # - produce exactly same result as `torch.concat(batch)`
             # - will achieve zero-copy if the tensor is contiguous
-            return batch[0].contiguous()
-        first_shape = batch[0].shape
-        if all(elem.shape[1:] == first_shape[1:] for elem in batch):
-            return torch.concat(batch)
+            # Replace original tensor so that its memory can be freed
+            # in the non-contiguous case.
+            batch[0] = batch[0].contiguous()
+            return batch[0]
+        first_shape = batch[0].shape[1:]
+        if all(elem.shape[1:] == first_shape for elem in batch):
+            concat = torch.concat(batch)
+            # Replace original tensors with slices into the new one,
+            # so that their memory can be freed.
+            start = 0
+            for i in range(len(batch)):
+                end = start + batch[i].shape[0]
+                batch[i] = concat[start:end]
+                start = end
+            return concat

         return [e for elem in batch for e in elem]
```
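The concat branch applies the same idea when tensors differ along dim 0: each original is swapped for the row slice of the concatenated result that corresponds to it, using the `start`/`end` bookkeeping above. A standalone sketch with illustrative shapes (not the production code):

```python
import torch

# Standalone sketch of the concat branch (assumed shapes, not vLLM code).
batch = [torch.randn(n, 4) for n in (2, 5, 3)]

concat = torch.concat(batch)          # shape (10, 4); data is copied
start = 0
for i in range(len(batch)):
    end = start + batch[i].shape[0]
    batch[i] = concat[start:end]      # zero-copy view over rows [start, end)
    start = end

# The originals are now backed only by `concat` (barring outside references),
# so their old buffers can be freed.
assert batch[0].data_ptr() == concat.data_ptr()
assert torch.equal(torch.concat(batch), concat)
```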
