From 297bc934438ef51f209b9dfbd8888fd0d56faf60 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 10 Apr 2025 18:11:21 -0700
Subject: [PATCH 1/2] [V1][Perf] Avoid mem duplication when aggregating MM
 tensors

Signed-off-by: Nick Hill
---
 vllm/multimodal/inputs.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 53729799b629..fb2514aebe4e 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -305,10 +305,18 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].unsqueeze(0).contiguous()
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0].unsqueeze(0)
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
-                return torch.stack(batch)
+                stack = torch.stack(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                for i in range(len(batch)):
+                    batch[i] = stack[i]
+                return stack
 
         return batch
 
@@ -337,10 +345,21 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`
                 # - will achieve zero-copy if the tensor is contiguous
-                return batch[0].contiguous()
-            first_shape = batch[0].shape
-            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
-                return torch.concat(batch)
+                # Replace original tensor so that its memory can be freed
+                # in the non-contiguous case.
+                batch[0] = batch[0].contiguous()
+                return batch[0]
+            first_shape = batch[0].shape[1:]
+            if all(elem.shape[1:] == first_shape for elem in batch):
+                concat = torch.concat(batch)
+                # Replace original tensors with slices into the new one,
+                # so that their memory can be freed.
+                off = 0
+                for i in range(len(batch)):
+                    size = batch[i].shape[0]
+                    batch[i] = concat[off:off + size]
+                    off += size
+                return concat
 
         return [e for elem in batch for e in elem]
 

From d728cc6d5db072fb5876f2b67cd17c1686353fa5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 10 Apr 2025 20:13:54 -0700
Subject: [PATCH 2/2] update

Signed-off-by: Nick Hill
---
 vllm/multimodal/inputs.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index fb2514aebe4e..650d17ef4edf 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -354,11 +354,11 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
                 concat = torch.concat(batch)
                 # Replace original tensors with slices into the new one,
                 # so that their memory can be freed.
-                off = 0
+                start = 0
                 for i in range(len(batch)):
-                    size = batch[i].shape[0]
-                    batch[i] = concat[off:off + size]
-                    off += size
+                    end = start + batch[i].shape[0]
+                    batch[i] = concat[start:end]
+                    start = end
                 return concat
 
         return [e for elem in batch for e in elem]
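
Note on the technique (editorial addition, not part of the patch series):
torch.stack and torch.concat copy their inputs, so right after the call each
input tensor's data exists twice, once in the caller's list and once in the
aggregated tensor. Rebinding every list element to a zero-copy view of the
new tensor drops the list's reference to the old storage, so the duplicate
can be freed as soon as no other reference remains. A minimal standalone
sketch of the idea, with made-up shapes (illustrative, not vLLM code):

    import torch

    # Stack three same-shape tensors, then alias the originals to views
    # of the stacked result so their old storage can be reclaimed.
    batch = [torch.randn(4, 8) for _ in range(3)]
    stack = torch.stack(batch)  # copies all three inputs
    for i in range(len(batch)):
        batch[i] = stack[i]  # zero-copy view; old storage now unreferenced
    assert batch[0].data_ptr() == stack.data_ptr()  # shares memory with stack

    # The flat-field path does the same along dim 0, slicing by each input's
    # length (the loop that PATCH 2/2 renames from off/size to start/end).
    parts = [torch.randn(n, 8) for n in (2, 3, 5)]
    concat = torch.concat(parts)
    start = 0
    for i in range(len(parts)):
        end = start + parts[i].shape[0]
        parts[i] = concat[start:end]  # view into the concatenated tensor
        start = end
    assert parts[-1].shape == (5, 8)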