@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import collections
 import contextlib
+import copy
 import functools
 import itertools
 import math
@@ -26,7 +27,10 @@
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import (VllmConfig, update_config)
-from vllm.forward_context import set_forward_context
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import get_sampler
@@ -407,6 +411,7 @@ def forward(self, *args, **kwargs):
         # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
         kwargs = kwargs.copy()
         # selected_token_indices = kwargs.pop('selected_token_indices')
+        is_warmup = kwargs.get('warmup_mode', False)
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
@@ -420,7 +425,12 @@ def forward(self, *args, **kwargs):
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
         with set_forward_context(attn_meta, self.vllm_config):
+            if not is_warmup:
+                self.maybe_start_load_kv()
             hidden_states = self.model(*args, **kwargs)
+            if not is_warmup:
+                self.maybe_wait_for_kv_save()
+
         if self._rotary_prepare_cos_sin is not None:
             self._reset_rotary_cos_sin()
         return hidden_states
@@ -431,6 +441,22 @@ def compute_logits(self, *args, **kwargs):
     # def sample(self, *args, **kwargs):
     #     return self.sampler(*args, **kwargs)

+    @staticmethod
+    def maybe_start_load_kv():
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save() -> None:
+        if has_kv_transfer_group():
+            get_kv_transfer_group().wait_for_save()
+
     def generate_proposals(self, *args, **kwargs):
         return self.model.generate_proposals(*args, **kwargs)

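Taken together, the two hunks above establish the connector's load/save contract: start_load_kv fires asynchronous KV-cache reads just before the model runs so the transfer overlaps with compute, and wait_for_save blocks after the forward pass so enqueued writes are durable before the step returns. Below is a minimal, self-contained sketch of that contract; ToyConnector and its threading internals are illustrative stand-ins, not vLLM's connector API.

    import threading


    class ToyConnector:
        def __init__(self) -> None:
            self._save_threads: list[threading.Thread] = []

        def start_load_kv(self, forward_context) -> None:
            # Kick off KV-cache reads in the background so they overlap
            # with model execution instead of serializing in front of it.
            threading.Thread(
                target=lambda: print("async KV load for", forward_context),
                daemon=True).start()

        def save_kv(self, layer_name: str) -> None:
            # In the real design saves are enqueued during the forward
            # pass (per layer); they are asynchronous as well.
            t = threading.Thread(
                target=lambda: print("async KV save for", layer_name),
                daemon=True)
            t.start()
            self._save_threads.append(t)

        def wait_for_save(self) -> None:
            # Block until every enqueued write has finished, so the step's
            # KV state is durable before its output is returned.
            for t in self._save_threads:
                t.join()
            self._save_threads.clear()


    connector = ToyConnector()
    connector.start_load_kv("step-0")    # before self.model(...)
    connector.save_kv("layers.0.attn")   # happens inside the forward pass
    connector.wait_for_save()            # after self.model(...)

Skipping both hooks in warmup mode is deliberate: warmup runs on dummy inputs, so there is no real KV state to load or save.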
@@ -716,6 +742,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_index = self.input_batch.remove_request(req_id)
             if req_index is not None:
                 removed_req_indices.append(req_index)
+            if req_id in self.input_batch.req_type:
+                del self.input_batch.req_type[req_id]

         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
@@ -862,6 +890,10 @@ def get_model(self) -> torch.nn.Module:
         assert self.model is not None
         return self.model

+    def is_decoder_only(self, req_id) -> bool:
+        return bool(req_id in self.input_batch.req_type
+                    and self.input_batch.req_type[req_id] == "decode")
+
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
@@ -871,24 +903,38 @@ def _get_prompts_and_decodes(
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0

+        if scheduler_output.kv_connector_metadata:
+            requests = scheduler_output.kv_connector_metadata.requests
+        else:
+            requests = None
+
         # Traverse decodes first
         decode_req_ids = []
         num_computed_tokens_decode = []
         for i in range(num_reqs):
             req_id = self.input_batch.req_ids[i]
             assert req_id is not None

+            if requests is not None and req_id not in self.input_batch.req_type:
+                for request in requests:
+                    if request.req_id == req_id:
+                        self.input_batch.req_type[req_id] = "prefill" \
+                            if request.load_spec is None else "decode"
+                        break
+
             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
             num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]

-            if num_computed_tokens < num_prompt_tokens:
+            if num_computed_tokens < num_prompt_tokens and \
+                    not self.is_decoder_only(req_id):
                 # This is prompt
                 break

             # This is decode
-            assert num_scheduled_tokens == 1
+            if not self.is_decoder_only(req_id):
+                assert num_scheduled_tokens == 1
             decode_req_ids.append(req_id)
             num_computed_tokens_decode.append(int(num_computed_tokens + 1))

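The traversal above decides each request's role once, from the connector metadata: a request without a load_spec is prefilled locally, while one with a load_spec receives its prompt KV through the connector and joins the batch as decode-only, which is why the num_scheduled_tokens == 1 assertion is relaxed for it. A hedged restatement of that rule with stand-in types; ConnectorRequest and the shape of load_spec are assumptions, not the actual metadata classes:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ConnectorRequest:
        req_id: str
        load_spec: Optional[object] = None  # set when prompt KV arrives remotely


    def classify(request: ConnectorRequest) -> str:
        # No load spec: this instance computes the prompt itself ("prefill").
        # With a load spec: the KV arrives through the connector, so the
        # request enters the batch as decode-only ("decode").
        return "prefill" if request.load_spec is None else "decode"


    assert classify(ConnectorRequest("a")) == "prefill"
    assert classify(ConnectorRequest("b", load_spec=object())) == "decode"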
@@ -1369,7 +1415,7 @@ def _prepare_inputs(
             num_scheduled_tokens.append(seq_num_scheduled_tokens)
             num_prompt_tokens.append(seq_num_prompt_tokens)
             # NOTE: assert that all the decodes are "decodes".
-            if idx < num_decodes:
+            if idx < num_decodes and not self.is_decoder_only(req_id):
                 assert seq_num_scheduled_tokens == 1
         return (self._prepare_prefill_inputs(num_prefills, num_decodes,
                                              num_scheduled_tokens),
@@ -1391,8 +1437,8 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata,
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
             logger.warning(
-                "Configuration: (%s, %s, %s, %s) was not warmed-up!", phase,
-                batch_size, seq_len, num_blocks)
+                "Configuration: rank (%s, %s, %s, %s, %s) was not warmed-up!",
+                os.getenv('RANK', '0'), phase, batch_size, seq_len, num_blocks)

     def _execute_model_generic(self,
                                token_ids,
@@ -1579,8 +1625,11 @@ def execute_model(

         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOutput if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            return self.kv_connector_no_forward(scheduler_output)
         # If necessary, swap decodes/prompts to have all decodes on the start
         ensure_decodes_first(self.input_batch)
         # Prepare prompts/decodes info
@@ -1607,11 +1656,14 @@ def execute_model(
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")
             htorch.core.mark_step()
+            self.maybe_setup_kv_connector(scheduler_output)
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
                     self.kv_caches)
             htorch.core.mark_step()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
             with self.profiler.record_event('internal', "sampler"):
                 sampling_metadata = self._prepare_sampling(
                     batch_changed, req_id, pad_to=logits_device.shape[0])
@@ -1645,11 +1697,15 @@ def execute_model(
                 self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
-            _, logits_device = self._execute_model_generic(
+            self.maybe_setup_kv_connector(scheduler_output)
+            _, logits_device = \
+                self._execute_model_generic(
                     decode_data.token_ids, decode_data.position_ids,
                     decode_data.attn_metadata, decode_data.logits_indices,
                     self.kv_caches)
             htorch.core.mark_step()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
             with self.profiler.record_event('internal', "sampler"):
                 sampling_metadata = self._prepare_sampling(
                     batch_changed,
@@ -1760,7 +1816,11 @@ def execute_model(
             spec_token_ids=None,
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
             pooler_output=[],
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
         )
+        if has_kv_transfer_group():
+            get_kv_transfer_group().clear_connector_metadata()
         return model_runner_output

     def load_model(self) -> None:
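With these pieces in place, every engine step walks the connector through the same four-phase lifecycle. The condensed sketch below is illustrative only; connector and run_forward are duck-typed stand-ins, not vLLM objects:

    def run_step(connector, scheduler_output, run_forward):
        # 1. Bind this step's metadata (maybe_setup_kv_connector).
        connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)
        # 2. Forward pass: loads start before it, saves are awaited after it.
        output = run_forward()
        # 3. Ask which async sends/receives completed this step
        #    (get_finished_kv_transfers).
        finished_sending, finished_recving = connector.get_finished(
            scheduler_output.finished_req_ids)
        # 4. Drop the bound metadata so stale state cannot leak into the
        #    next step (clear_connector_metadata).
        connector.clear_connector_metadata()
        return output, finished_sending, finished_recving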
@@ -2450,3 +2510,41 @@ def reload_weights(self) -> None:
         logger.info("Reloading weights inplace...")
         model_loader.load_weights(self.model, model_config=self.model_config)
         torch.hpu.synchronize()
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update the KVConnector with this step's metadata before forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+    @staticmethod
+    def get_finished_kv_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # Run KV send/recv even when there is no forward pass to do.
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            if has_kv_transfer_group():
+                kv_connector = get_kv_transfer_group()
+                kv_connector.start_load_kv(get_forward_context())
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
+
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
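One subtlety in kv_connector_no_forward: EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level constant, so it must never be mutated in place; copy.copy gives the step a private instance to attach the finished-transfer sets to. A toy reproduction of the hazard with a stand-in dataclass (not vLLM's ModelRunnerOutput):

    import copy
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Output:
        finished_sending: Optional[set] = None
        finished_recving: Optional[set] = None


    EMPTY = Output()

    alias = EMPTY                        # aliasing the shared constant...
    alias.finished_sending = {"req-1"}   # ...mutates it for every later caller
    assert EMPTY.finished_sending == {"req-1"}  # corrupted

    EMPTY = Output()                     # reset; now the safe pattern
    private = copy.copy(EMPTY)           # shallow copy, as the code above does
    private.finished_sending = {"req-1"}
    assert EMPTY.finished_sending is None  # shared constant untouched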