Skip to content

Commit aefeea0

Browse files
sdavidbd and David Ben-David authored
[V1] [P/D] Refactor KV Connector Path (#21980)
Signed-off-by: David Ben-David <[email protected]>
Co-authored-by: David Ben-David <[email protected]>
1 parent 24d1dff commit aefeea0

File tree

12 files changed: +141 additions, -79 deletions (lines changed)

12 files changed: +141 additions, -79 deletions (lines changed)

tests/v1/kv_connector/unit/test_output_aggreagator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,24 @@
44
from typing import Optional
55

66
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
7-
from vllm.v1.outputs import ModelRunnerOutput
7+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
88

99

1010
class DummyModelRunnerOutput(ModelRunnerOutput):
1111

1212
def __init__(self,
1313
finished_sending: Optional[set[str]] = None,
1414
finished_recving: Optional[set[str]] = None):
15-
self.finished_sending = finished_sending
16-
self.finished_recving = finished_recving
15+
self.kv_connector_output = KVConnectorOutput(
16+
finished_sending=finished_sending,
17+
finished_recving=finished_recving,
18+
)
19+
20+
def __repr__(self):
21+
return (
22+
f"DummyModelRunnerOutput("
23+
f"finished_sending={self.kv_connector_output.finished_sending},"
24+
f"finished_recving={self.kv_connector_output.finished_recving})")
1725

1826

1927
def test_aggregate_workers_output():
@@ -27,6 +35,7 @@ def test_aggregate_workers_output():
2735
aggregated = aggregator.aggregate([output1, output2])
2836

2937
assert aggregated is output1
38+
aggregated = aggregated.kv_connector_output
3039
assert aggregated.finished_sending is None
3140
assert aggregated.finished_recving is None
3241

@@ -38,6 +47,7 @@ def test_aggregate_workers_output():
3847
aggregated = aggregator.aggregate([output1, output2])
3948

4049
assert aggregated is output1
50+
aggregated = aggregated.kv_connector_output
4151
assert aggregated.finished_sending == {'req1'}
4252
assert aggregated.finished_recving is None
4353

@@ -49,6 +59,7 @@ def test_aggregate_workers_output():
4959
aggregated = aggregator.aggregate([output1, output2])
5060

5161
assert aggregated is output1
62+
aggregated = aggregated.kv_connector_output
5263
assert aggregated.finished_sending is None
5364
assert aggregated.finished_recving == {'req2'}
5465

@@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
7081
assert result_future.done()
7182
aggregated = result_future.result()
7283
assert aggregated is output1
84+
aggregated = aggregated.kv_connector_output
7385
assert aggregated.finished_sending is None
7486
assert aggregated.finished_recving is None
7587

@@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
8799
assert result_future.done()
88100
aggregated = result_future.result()
89101
assert aggregated is output1
102+
aggregated = aggregated.kv_connector_output
90103
assert aggregated.finished_sending == {'req1'}
91104
assert aggregated.finished_recving is None
92105

@@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
104117
assert result_future.done()
105118
aggregated = result_future.result()
106119
assert aggregated is output1
120+
aggregated = aggregated.kv_connector_output
107121
assert aggregated.finished_sending is None
108122
assert aggregated.finished_recving == {'req2'}

tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import copy
44

5-
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
5+
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
66
from vllm.v1.request import FinishReason, RequestStatus
77

88
from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -86,7 +86,8 @@ def test_basic_lifecycle():
8686

8787
# (3b): execute_model()
8888
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
89-
model_runner_output.finished_sending = [request_id]
89+
model_runner_output.kv_connector_output = KVConnectorOutput(
90+
finished_sending=[request_id])
9091

9192
# (3c): update_from_output()
9293
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
176177
scheduler_output = scheduler.schedule()
177178
scheduler.schedule()
178179
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
179-
model_runner_output.finished_sending = [request_remote.request_id]
180+
model_runner_output.kv_connector_output = KVConnectorOutput(
181+
finished_sending=[request_remote.request_id])
180182
scheduler.update_from_output(scheduler_output, model_runner_output)
181183
_ = scheduler.schedule()
182184
assert_scheduler_empty(scheduler)

tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import copy
44

5-
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
5+
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
66
from vllm.v1.request import FinishReason, RequestStatus
77

88
from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -72,7 +72,8 @@ def test_basic_lifecycle():
7272

7373
# (2b): forward(): request finishes recv.
7474
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
75-
model_runner_output.finished_recving = [request_id]
75+
model_runner_output.kv_connector_output = KVConnectorOutput(
76+
finished_recving=[request_id])
7677

7778
# (2c): update_from_output():
7879
engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -309,7 +310,8 @@ def test_full_block_prompt():
309310
# # STEP (2): Recv.
310311
scheduler_output = scheduler.schedule()
311312
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
312-
model_runner_output.finished_recving = [request_id]
313+
model_runner_output.kv_connector_output = KVConnectorOutput(
314+
finished_recving=[request_id])
313315
scheduler.update_from_output(scheduler_output, model_runner_output)
314316
assert len(scheduler.waiting) == 1
315317
assert (request_id in scheduler.finished_recving_kv_req_ids)

tests/v1/kv_connector/unit/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from vllm.v1.core.sched.scheduler import Scheduler
1818
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
1919
KVCacheGroupSpec)
20-
from vllm.v1.outputs import ModelRunnerOutput
20+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
2121
from vllm.v1.request import Request
2222
from vllm.v1.structured_output import StructuredOutputManager
2323

@@ -188,8 +188,10 @@ def create_model_runner_output(
188188
logprobs=None,
189189
prompt_logprobs_dict={},
190190
pooler_output=None,
191-
finished_sending=finished_sending,
192-
finished_recving=finished_recving,
191+
kv_connector_output=KVConnectorOutput(
192+
finished_sending=finished_sending,
193+
finished_recving=finished_recving,
194+
),
193195
)
194196

195197

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from vllm.distributed.kv_transfer.kv_connector.factory import (
1717
KVConnectorFactory)
1818
from vllm.logger import init_logger
19-
from vllm.v1.outputs import ModelRunnerOutput
19+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
2020

2121
logger = init_logger(__name__)
2222

@@ -129,7 +129,7 @@ def __init__(self, world_size: int):
129129
def aggregate(self,
130130
outputs: list[ModelRunnerOutput],
131131
output_rank: int = 0) -> ModelRunnerOutput:
132-
# aggregate finished_sending, finished_recving from all workers
132+
# aggregate kv_connector_output from all workers
133133

134134
def update_finished_set(req_ids: Optional[set[str]],
135135
remaining_count_dict: dict[str, int],
@@ -143,6 +143,7 @@ def update_finished_set(req_ids: Optional[set[str]],
143143
finished_sending = set[str]()
144144
finished_recving = set[str]()
145145
for output in outputs:
146+
output = output.kv_connector_output
146147
update_finished_set(output.finished_sending,
147148
self._send_remaining_count, finished_sending)
148149
update_finished_set(output.finished_recving,
@@ -151,13 +152,10 @@ def update_finished_set(req_ids: Optional[set[str]],
151152
# select output of the worker specified by output_rank
152153
output = outputs[output_rank]
153154

154-
# set the aggregated finished_sending / finished_recving
155-
# if output.finished_sending/recving is not empty, but the other ranks
156-
# still have unfinished send/recv, we want to set the aggregated
157-
# finished_sending/recving to None until all ranks have finished
158-
# send/recv
159-
output.finished_sending = finished_sending if finished_sending else None
160-
output.finished_recving = finished_recving if finished_recving else None
155+
output.kv_connector_output = KVConnectorOutput(
156+
finished_sending=finished_sending or None,
157+
finished_recving=finished_recving or None,
158+
)
161159

162160
return output
163161

vllm/sequence.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Sequence as GenericSequence
1111
from dataclasses import dataclass, field
1212
from functools import reduce
13-
from typing import Any, Callable, Optional, Union
13+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1414

1515
import msgspec
1616
import torch
@@ -21,6 +21,10 @@
2121
from vllm.pooling_params import PoolingParams
2222
from vllm.sampling_params import RequestOutputKind, SamplingParams
2323

24+
if TYPE_CHECKING:
25+
from vllm.v1.worker.kv_connector_model_runner_mixin import (
26+
KVConnectorOutput)
27+
2428
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
2529

2630
VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
11591163
states and residuals to be sent to the next stage. This data structure
11601164
contains the hidden states and residuals for a request.
11611165
1162-
Each stage also needs to handle its own finished_sending and
1163-
finished_recving in case of kv transfer.
1166+
Each stage also needs to handle its own kv_connector_output.
11641167
"""
11651168

11661169
tensors: dict[str, torch.Tensor]
1167-
# [req_ids]
1168-
finished_sending: Optional[set[str]] = None
1169-
finished_recving: Optional[set[str]] = None
1170+
kv_connector_output: Optional["KVConnectorOutput"]
11701171

11711172
def __init__(self, tensors):
11721173
# manually define this function, so that

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
EngineCoreOutputs)
3131
from vllm.v1.kv_cache_interface import KVCacheConfig
3232
from vllm.v1.metrics.stats import SchedulerStats
33-
from vllm.v1.outputs import ModelRunnerOutput
33+
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
3434
from vllm.v1.request import Request, RequestStatus
3535
from vllm.v1.spec_decode.metrics import SpecDecodingStats
3636
from vllm.v1.structured_output import StructuredOutputManager
@@ -884,7 +884,9 @@ def update_from_output(
884884
self.waiting.remove_requests(stopped_preempted_reqs)
885885

886886
# KV Connector: update state for finished KV Transfers.
887-
self._update_from_kv_xfer_finished(model_runner_output)
887+
if model_runner_output.kv_connector_output:
888+
self._update_from_kv_xfer_finished(
889+
model_runner_output.kv_connector_output)
888890

889891
# Create EngineCoreOutputs for all clients that have requests with
890892
# outputs in this step.
@@ -1128,7 +1130,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
11281130
return True
11291131

11301132
def _update_from_kv_xfer_finished(self,
1131-
model_runner_output: ModelRunnerOutput):
1133+
kv_connector_output: KVConnectorOutput):
11321134
"""
11331135
KV Connector: update the scheduler state based on the output.
11341136
@@ -1139,9 +1141,9 @@ def _update_from_kv_xfer_finished(self,
11391141
scheduler the request during the next step.
11401142
"""
11411143
# KV Connector:: update recv and send status from last step.
1142-
for req_id in (model_runner_output.finished_recving or ()):
1144+
for req_id in (kv_connector_output.finished_recving or ()):
11431145
logger.debug("Finished recving KV transfer for request %s", req_id)
11441146
self.finished_recving_kv_req_ids.add(req_id)
1145-
for req_id in (model_runner_output.finished_sending or ()):
1147+
for req_id in (kv_connector_output.finished_sending or ()):
11461148
logger.debug("Finished sending KV transfer for request %s", req_id)
11471149
self._free_blocks(self.requests[req_id])

vllm/v1/outputs.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ class SamplerOutput:
7171
logprobs_tensors: Optional[LogprobsTensors]
7272

7373

74+
@dataclass
75+
class KVConnectorOutput:
76+
# [req_ids]
77+
finished_sending: Optional[set[str]] = None
78+
finished_recving: Optional[set[str]] = None
79+
80+
7481
# ModelRunnerOutput is serialized and sent to the scheduler process.
7582
# This is expensive for torch.Tensor so prefer to use list instead.
7683
@dataclass
@@ -104,9 +111,7 @@ class ModelRunnerOutput:
104111
# [num_reqs, hidden_size]
105112
pooler_output: list[Optional[torch.Tensor]]
106113

107-
# [req_ids]
108-
finished_sending: Optional[set[str]] = None
109-
finished_recving: Optional[set[str]] = None
114+
kv_connector_output: Optional[KVConnectorOutput] = None
110115

111116
# req_id -> num_nans_in_logits
112117
num_nans_in_logits: Optional[dict[str, int]] = None
@@ -119,6 +124,4 @@ class ModelRunnerOutput:
119124
logprobs=None,
120125
prompt_logprobs_dict={},
121126
pooler_output=[],
122-
finished_sending=None,
123-
finished_recving=None,
124127
num_nans_in_logits=None)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
7070
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
7171
from vllm.v1.worker.kv_connector_model_runner_mixin import (
72-
KVConnectorModelRunnerMixin)
72+
KVConnectorModelRunnerMixin, KVConnectorOutput)
7373
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
7474

7575
from ..sample.logits_processor import LogitsProcessorManager
@@ -1423,8 +1423,7 @@ def _pool(
14231423
hidden_states: torch.Tensor,
14241424
num_scheduled_tokens: int,
14251425
num_scheduled_tokens_np: np.ndarray,
1426-
finished_sending: Optional[set[str]],
1427-
finished_recving: Optional[set[str]],
1426+
kv_connector_output: Optional[KVConnectorOutput],
14281427
) -> ModelRunnerOutput:
14291428
assert self.input_batch.num_reqs ==\
14301429
len(self.input_batch.pooling_params), \
@@ -1459,8 +1458,7 @@ def _pool(
14591458
logprobs=None,
14601459
prompt_logprobs_dict={},
14611460
pooler_output=pooler_output,
1462-
finished_sending=finished_sending,
1463-
finished_recving=finished_recving,
1461+
kv_connector_output=kv_connector_output,
14641462
)
14651463

14661464
@torch.inference_mode()
@@ -1564,8 +1562,8 @@ def execute_model(
15641562
num_tokens=num_input_tokens,
15651563
num_tokens_across_dp=num_tokens_across_dp,
15661564
skip_cuda_graphs=skip_cuda_graphs,
1567-
):
1568-
self.maybe_setup_kv_connector(scheduler_output)
1565+
), self.maybe_get_kv_connector_output(
1566+
scheduler_output) as kv_connector_output:
15691567

15701568
model_output = self.model(
15711569
input_ids=input_ids,
@@ -1578,10 +1576,6 @@ def execute_model(
15781576
),
15791577
)
15801578

1581-
self.maybe_wait_for_kv_save()
1582-
finished_sending, finished_recving = (
1583-
self.get_finished_kv_transfers(scheduler_output))
1584-
15851579
if self.use_aux_hidden_state_outputs:
15861580
hidden_states, aux_hidden_states = model_output
15871581
else:
@@ -1597,20 +1591,17 @@ def execute_model(
15971591
== "external_launcher" and len(get_pp_group().ranks) > 0
15981592
if not get_pp_group().is_last_rank:
15991593
# For mid-pipeline stages, return the hidden states.
1594+
assert isinstance(hidden_states, IntermediateTensors)
16001595
if not broadcast_pp_output:
1601-
if finished_sending or finished_recving:
1602-
hidden_states.finished_sending = finished_sending
1603-
hidden_states.finished_recving = finished_recving
1596+
hidden_states.kv_connector_output = kv_connector_output
16041597
return hidden_states
1605-
assert isinstance(hidden_states, IntermediateTensors)
16061598
get_pp_group().send_tensor_dict(hidden_states.tensors,
16071599
all_gather_group=get_tp_group())
16081600
logits = None
16091601
else:
16101602
if self.input_batch.pooling_params:
16111603
return self._pool(hidden_states, num_scheduled_tokens,
1612-
num_scheduled_tokens_np, finished_sending,
1613-
finished_recving)
1604+
num_scheduled_tokens_np, kv_connector_output)
16141605

16151606
sample_hidden_states = hidden_states[logits_indices]
16161607
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1760,8 +1751,7 @@ def execute_model(
17601751
logprobs=logprobs_lists,
17611752
prompt_logprobs_dict=prompt_logprobs_dict,
17621753
pooler_output=[],
1763-
finished_sending=finished_sending,
1764-
finished_recving=finished_recving,
1754+
kv_connector_output=kv_connector_output,
17651755
num_nans_in_logits=num_nans_in_logits,
17661756
)
17671757

0 commit comments

Comments
 (0)