Skip to content

Commit 5898b13

Browse files
authored
[BugFix] Fix KVConnectorOutput TPU breakage (#22598)
Signed-off-by: Nick Hill <[email protected]>
1 parent b799f4b commit 5898b13

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

tests/v1/kv_connector/unit/utils.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,13 @@ def create_model_runner_output(
179179
sampled_token = EOS_TOKEN_ID if use_eos else 0
180180
sampled_token_ids = [[sampled_token] for _ in req_ids]
181181

182+
kv_connector_output = None if (
183+
finished_sending is None
184+
and finished_recving is None) else KVConnectorOutput(
185+
finished_sending=finished_sending,
186+
finished_recving=finished_recving,
187+
)
188+
182189
# Make output data structure.
183190
return ModelRunnerOutput(
184191
req_ids=req_ids,
@@ -188,10 +195,7 @@ def create_model_runner_output(
188195
logprobs=None,
189196
prompt_logprobs_dict={},
190197
pooler_output=None,
191-
kv_connector_output=KVConnectorOutput(
192-
finished_sending=finished_sending,
193-
finished_recving=finished_recving,
194-
),
198+
kv_connector_output=kv_connector_output,
195199
)
196200

197201

vllm/v1/core/sched/scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1151,8 +1151,8 @@ def _update_from_kv_xfer_finished(self,
11511151
schedule the request during the next step.
11521152
"""
11531153

1154-
assert self.connector is not None
1155-
self.connector.update_connector_output(kv_connector_output)
1154+
if self.connector is not None:
1155+
self.connector.update_connector_output(kv_connector_output)
11561156

11571157
# KV Connector: update recv and send status from last step.
11581158
for req_id in (kv_connector_output.finished_recving or ()):

vllm/v1/worker/tpu_model_runner.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1138,6 +1138,13 @@ def concat_lists(input_lists):
11381138
i, target_slice] = valid_sampled_token_ids[i]
11391139
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
11401140

1141+
kv_connector_output = None if (
1142+
finished_sending is None
1143+
and finished_recving is None) else KVConnectorOutput(
1144+
finished_sending=finished_sending,
1145+
finished_recving=finished_recving,
1146+
)
1147+
11411148
model_runner_output = ModelRunnerOutput(
11421149
req_ids=req_ids,
11431150
req_id_to_index=self.input_batch.req_id_to_index,
@@ -1146,10 +1153,8 @@ def concat_lists(input_lists):
11461153
logprobs=logprobs_lists,
11471154
prompt_logprobs_dict=prompt_logprobs_dict,
11481155
pooler_output=[],
1149-
kv_connector_output=KVConnectorOutput(
1150-
finished_sending=finished_sending,
1151-
finished_recving=finished_recving,
1152-
))
1156+
kv_connector_output=kv_connector_output,
1157+
)
11531158

11541159
# Check there are no new graphs compiled - all the graphs should be
11551160
# captured and compiled during warm up.

0 commit comments

Comments
 (0)