diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 53729799b629..650d17ef4edf 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: # An optimization when `batch` contains only one tensor: # - produce exactly same result as `torch.stack(batch)` # - will achieve zero-copy if the tensor is contiguous - return batch[0].unsqueeze(0).contiguous() + # Replace original tensor so that its memory can be freed + # in the non-contiguous case. + batch[0] = batch[0].contiguous() + return batch[0].unsqueeze(0) first_shape = batch[0].shape if all(elem.shape == first_shape for elem in batch): - return torch.stack(batch) + stack = torch.stack(batch) + # Replace original tensors with slices into the new one, + # so that their memory can be freed. + for i in range(len(batch)): + batch[i] = stack[i] + return stack return batch @@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: # An optimization when `batch` contains only one tensor: # - produce exactly same result as `torch.concat(batch)` # - will achieve zero-copy if the tensor is contiguous - return batch[0].contiguous() - first_shape = batch[0].shape - if all(elem.shape[1:] == first_shape[1:] for elem in batch): - return torch.concat(batch) + # Replace original tensor so that its memory can be freed + # in the non-contiguous case. + batch[0] = batch[0].contiguous() + return batch[0] + first_shape = batch[0].shape[1:] + if all(elem.shape[1:] == first_shape for elem in batch): + concat = torch.concat(batch) + # Replace original tensors with slices into the new one, + # so that their memory can be freed. + start = 0 + for i in range(len(batch)): + end = start + batch[i].shape[0] + batch[i] = concat[start:end] + start = end + return concat return [e for elem in batch for e in elem]