Commit 31bff69

Make _prepare_sample non-blocking and use pinned memory for input buffers (#2207)
1 parent: ba4f826 · commit: 31bff69
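
At its core this change applies a standard PyTorch pattern: build host-side
tensors in pinned (page-locked) memory and move them to the GPU with
non-blocking copies, so the transfer runs as an async DMA while the CPU keeps
working. A minimal sketch of the pattern, separate from the diff below:

    import torch

    # Allocating the host tensor in pinned (page-locked) memory lets the
    # CUDA driver DMA it to the device without an extra staging copy.
    host = torch.tensor([1, 2, 3, 4], dtype=torch.long, pin_memory=True)

    # non_blocking=True queues the copy on the current CUDA stream and
    # returns immediately; the host thread is free to do other work.
    dev = host.to(device="cuda", non_blocking=True)

    # Kernels launched later on the same stream are ordered after the copy,
    # so no explicit synchronization is needed before using `dev`.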

File tree: 1 file changed (+38 −17 lines)

vllm/worker/model_runner.py

Lines changed: 38 additions & 17 deletions
@@ -10,6 +10,7 @@
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)

@@ -52,6 +53,8 @@ def __init__(
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # Cache the in_wsl result; pinned memory is not supported on WSL.
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
@@ -203,24 +206,29 @@ def _prepare_decode(
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
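
Note the gating above: tensors are pinned only on the CUDA-graph path, where
they are built on the CPU and copied into fixed GPU buffers later, and never
under WSL, where page-locked CUDA memory is unsupported. The in_wsl helper
comes from vllm.utils and is not shown in this diff; a sketch of a typical
detection heuristic, assuming it inspects the kernel identification string:

    import platform

    def in_wsl() -> bool:
        # WSL kernels advertise "microsoft" in their uname fields; CUDA
        # pinned-memory allocation is not supported there.
        return "microsoft" in " ".join(platform.uname()).lower()
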
@@ -229,7 +237,7 @@ def _prepare_decode(
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
@@ -297,11 +305,11 @@ def _prepare_sample(
                     categorized_sample_indices_start_idx + num_seqs))
             categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }

@@ -334,8 +342,6 @@ def execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@ def execute_model(
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
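
Moving _prepare_sample below the forward launch is the "non-blocking" half of
the commit: CUDA kernel launches return to Python almost immediately, so the
CPU builds the sampling tensors (and starts their async H2D copies) while the
GPU is still busy with the forward pass. A toy illustration of why the
reordering overlaps work yet stays safe, using hypothetical names:

    import torch

    def overlap_demo(x: torch.Tensor, indices: list) -> torch.Tensor:
        # "Forward pass": the matmul kernels are queued asynchronously,
        # so this call returns long before the GPU finishes.
        hidden = x @ x

        # Meanwhile, build the "sampling" indices on the CPU and start an
        # async H2D copy; both overlap with the matmul above.
        idx = torch.tensor(indices, dtype=torch.long, pin_memory=True)
        idx = idx.to(device="cuda", non_blocking=True)

        # This gather is queued on the same stream after the matmul and
        # the copy, so the result never depends on timing.
        return hidden[idx]

    out = overlap_demo(torch.randn(1024, 1024, device="cuda"), [0, 3, 5])
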
@@ -502,11 +511,14 @@ def forward(
         del kv_caches
 
         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)
 
         # Run the graph.
         self.graph.replay()
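
With the sources now pinned by _prepare_decode, these copy_ calls with
non_blocking=True can run as true async DMAs into the graph's fixed input
buffers; for pageable sources (the WSL path) PyTorch falls back to a
synchronous copy, so the code stays correct either way. graph.replay() is
issued on the same stream and is therefore ordered after all five copies.
A self-contained sketch of the buffer-copy pattern, with illustrative shapes:

    import torch

    # Stand-in for one of the captured graph's fixed input buffers.
    input_ids_buf = torch.empty(8, dtype=torch.long, device="cuda")

    # Pinned host source, as _prepare_decode now produces on the graph path.
    src = torch.zeros(8, dtype=torch.long, pin_memory=True)

    # Async DMA into the fixed buffer; replaying the graph on the same
    # stream afterwards observes the completed copy.
    input_ids_buf.copy_(src, non_blocking=True)
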
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
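
The str(device) == "cpu" guard matters because pinning is a host-memory
concept: asking torch.tensor for a pinned CUDA allocation raises a runtime
error, and normalizing through str() also accepts torch.device objects.
A quick demonstration:

    import torch

    ok = torch.tensor([1, 2], device="cpu", pin_memory=True)  # fine
    # torch.tensor([1, 2], device="cuda", pin_memory=True)  # raises, e.g.
    # "Only dense CPU tensors can be pinned"
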
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
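
_async_h2d packages the whole pattern: allocate directly in pinned memory,
then hand off with to("cuda", non_blocking=True). When pin_memory=False (the
WSL path), the same call silently degrades to an ordinary blocking copy, so
call sites need no branching. Example usage with illustrative values:

    # Both calls return before the copies necessarily complete; kernels
    # queued later on the default stream see the data in order.
    selected_token_indices = _async_h2d([0, 5, 9], dtype=torch.long,
                                        pin_memory=True)
    sample_indices = _async_h2d([0, 1], dtype=torch.int, pin_memory=True)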
