Skip to content

Commit 319cb1e

Browse files
authored (author name lost in extraction; see Signed-off-by line below)
[Core] Batch multi modal input using pinned memory (#19169)
Signed-off-by: Lukas Geiger <[email protected]>
1 parent 1efef71 commit 319cb1e

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

vllm/multimodal/inputs.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,8 @@ def modalities(self):
680680
return self._items_by_modality.keys()
681681

682682
@staticmethod
683-
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
683+
def _try_stack(nested_tensors: NestedTensors,
684+
pin_memory: bool = False) -> NestedTensors:
684685
"""
685686
Stack the inner dimensions that have the same shape in
686687
a nested list of tensors.
@@ -697,7 +698,9 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
697698
if isinstance(nested_tensors, (int, float)):
698699
return torch.tensor(nested_tensors)
699700

700-
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
701+
stacked = [
702+
MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
703+
]
701704
if not is_list_of(stacked, torch.Tensor, check="all"):
702705
# Only tensors (not lists) can be stacked.
703706
return stacked
@@ -713,10 +716,16 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
713716
# The tensors have incompatible shapes and can't be stacked.
714717
return tensors_
715718

716-
return torch.stack(tensors_)
719+
outputs = torch.empty(len(tensors_),
720+
*tensors_[0].shape,
721+
dtype=tensors_[0].dtype,
722+
device=tensors_[0].device,
723+
pin_memory=pin_memory)
724+
return torch.stack(tensors_, out=outputs)
717725

718726
@staticmethod
719-
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
727+
def batch(inputs_list: list["MultiModalKwargs"],
728+
pin_memory: bool = False) -> BatchedTensorInputs:
720729
"""
721730
Batch multiple inputs together into a dictionary.
722731
@@ -738,7 +747,7 @@ def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
738747
item_lists[k].append(v)
739748

740749
return {
741-
k: MultiModalKwargs._try_stack(item_list)
750+
k: MultiModalKwargs._try_stack(item_list, pin_memory)
742751
for k, item_list in item_lists.items()
743752
}
744753

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
962962

963963
encoder_outputs = []
964964
for grouped_mm_inputs in grouped_mm_inputs_list:
965-
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
965+
batched_mm_inputs = MultiModalKwargs.batch(
966+
grouped_mm_inputs, pin_memory=self.pin_memory)
966967
batched_mm_inputs = MultiModalKwargs.as_kwargs(
967968
batched_mm_inputs,
968969
device=self.device,
@@ -1989,7 +1990,8 @@ def profile_run(self) -> None:
19891990
).multi_modal_data
19901991

19911992
batched_dummy_mm_inputs = MultiModalKwargs.batch(
1992-
[dummy_mm_kwargs] * max_num_mm_items)
1993+
[dummy_mm_kwargs] * max_num_mm_items,
1994+
pin_memory=self.pin_memory)
19931995
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
19941996
batched_dummy_mm_inputs,
19951997
device=self.device,

0 commit comments

Comments (0)