 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl

 logger = init_logger(__name__)

@@ -52,6 +53,8 @@ def __init__(
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # Cache the in_wsl result.
+        self.in_wsl = in_wsl()

     def load_model(self) -> None:
         self.model = get_model(self.model_config)
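
The in_wsl() check exists because pinned (page-locked) host memory is not reliably supported under WSL, so the code falls back to pageable allocations there. As a rough sketch of what such a helper checks (the real one lives in vllm.utils and may differ), WSL kernels advertise themselves in the platform release string:

    import platform

    def in_wsl() -> bool:
        # WSL kernels report "microsoft" in the release string,
        # e.g. "5.15.90.1-microsoft-standard-WSL2".
        return "microsoft" in platform.uname().release.lower()
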
@@ -203,24 +206,29 @@ def _prepare_decode(
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)

         if use_captured_graph:
             # The shape of graph_block_tables is
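
When a captured CUDA graph will consume these inputs, the tensors are staged on the CPU and later copied into the graph's fixed device buffers; pinning the staging memory is what allows those copies to be truly asynchronous. A toy demonstration of the pattern (illustrative only, not vLLM code):

    import torch

    # A page-locked staging tensor lets copy_(..., non_blocking=True)
    # overlap the host-to-device transfer with other GPU work.
    staging = torch.tensor([1, 2, 3, 4], dtype=torch.long, pin_memory=True)
    gpu_buf = torch.empty(4, dtype=torch.long, device="cuda")
    gpu_buf.copy_(staging, non_blocking=True)  # returns immediately
    torch.cuda.synchronize()                   # wait before reading gpu_buf
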
@@ -229,7 +237,7 @@ def _prepare_decode(
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
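
The switch from torch.from_numpy(...).to(device) to torch.tensor(...) looks cosmetic but likely is not: with device == "cpu", .to("cpu") is a no-op, so the resulting tensor would alias the reused numpy scratch buffer, whereas torch.tensor always takes a copy. A minimal illustration (the motivation is my inference; the aliasing behavior itself is standard PyTorch):

    import numpy as np
    import torch

    arr = np.zeros(4, dtype=np.int64)
    alias = torch.from_numpy(arr).to("cpu")  # no-op .to(): shares arr's memory
    copy = torch.tensor(arr)                 # torch.tensor always copies
    arr[0] = 7
    print(alias[0].item(), copy[0].item())   # prints: 7 0
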
@@ -297,11 +305,11 @@ def _prepare_sample(
                     categorized_sample_indices_start_idx + num_seqs))
             categorized_sample_indices_start_idx += num_seqs

-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }

@@ -334,8 +342,6 @@ def execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)

         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@ def execute_model(
             input_metadata=input_metadata,
         )

+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
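
Moving _prepare_sample after the forward launch is an overlap optimization: kernel launches are asynchronous, so the CPU can build the sampling metadata while the GPU is still executing the model. The general pattern, as a self-contained sketch:

    import torch

    x = torch.randn(4096, 4096, device="cuda")
    y = x @ x                        # enqueued on the CUDA stream; returns at once
    meta = sorted(range(100_000))    # CPU work that overlaps the matmul above
    torch.cuda.synchronize()         # only now does the CPU wait for the GPU
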
@@ -502,11 +511,14 @@ def forward(
         del kv_caches

         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)

         # Run the graph.
         self.graph.replay()
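
Making these copies non_blocking is safe here because work on a single CUDA stream executes in issue order: the copies are enqueued before graph.replay(), so the replayed kernels cannot observe stale buffer contents. A minimal ordering demo (the buffers are stand-ins of my own):

    import torch

    buf = torch.empty(4, device="cuda")
    src = torch.tensor([1.0, 2.0, 3.0, 4.0], pin_memory=True)
    buf.copy_(src, non_blocking=True)  # enqueued first
    out = buf * 10                     # enqueued second; sees the copied values
    torch.cuda.synchronize()
    print(out)  # tensor([10., 20., 30., 40.], device='cuda:0')
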
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")


 def _get_graph_batch_size(batch_size: int) -> int:
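
Note the str(device) == "cpu" guard: pinning only applies to host allocations, so GPU-bound tensors silently skip it. The _pad_to_max helper this builds on is presumably the usual right-padding routine, along these lines (sketch, not the verbatim helper):

    def _pad_to_max(x: list, max_len: int, pad: int) -> list:
        # Right-pad x with `pad` up to max_len.
        assert len(x) <= max_len
        return x + [pad] * (max_len - len(x))
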
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
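
_async_h2d ties the pieces together: the Python list is first materialized in (optionally pinned) host memory, then shipped to the GPU without blocking the CPU. An illustrative call, mirroring how _prepare_sample uses it:

    indices = _async_h2d([0, 5, 9], dtype=torch.long, pin_memory=not in_wsl())

No explicit synchronization is needed afterwards: the sampler's kernels are issued on the same stream and are therefore ordered after the copy.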