File tree Expand file tree Collapse file tree 3 files changed +19
-10
lines changed
tests/v1/kv_connector/unit Expand file tree Collapse file tree 3 files changed +19
-10
lines changed Original file line number Diff line number Diff line change @@ -179,6 +179,13 @@ def create_model_runner_output(
179
179
sampled_token = EOS_TOKEN_ID if use_eos else 0
180
180
sampled_token_ids = [[sampled_token ] for _ in req_ids ]
181
181
182
+ kv_connector_output = None if (
183
+ finished_sending is None
184
+ and finished_recving is None ) else KVConnectorOutput (
185
+ finished_sending = finished_sending ,
186
+ finished_recving = finished_recving ,
187
+ )
188
+
182
189
# Make output data structure.
183
190
return ModelRunnerOutput (
184
191
req_ids = req_ids ,
@@ -188,10 +195,7 @@ def create_model_runner_output(
188
195
logprobs = None ,
189
196
prompt_logprobs_dict = {},
190
197
pooler_output = None ,
191
- kv_connector_output = KVConnectorOutput (
192
- finished_sending = finished_sending ,
193
- finished_recving = finished_recving ,
194
- ),
198
+ kv_connector_output = kv_connector_output ,
195
199
)
196
200
197
201
Original file line number Diff line number Diff line change @@ -1151,8 +1151,8 @@ def _update_from_kv_xfer_finished(self,
1151
1151
scheduler the request during the next step.
1152
1152
"""
1153
1153
1154
- assert self .connector is not None
1155
- self .connector .update_connector_output (kv_connector_output )
1154
+ if self .connector is not None :
1155
+ self .connector .update_connector_output (kv_connector_output )
1156
1156
1157
1157
# KV Connector:: update recv and send status from last step.
1158
1158
for req_id in (kv_connector_output .finished_recving or ()):
Original file line number Diff line number Diff line change @@ -1138,6 +1138,13 @@ def concat_lists(input_lists):
1138
1138
i , target_slice ] = valid_sampled_token_ids [i ]
1139
1139
req_state .output_token_ids .extend (valid_sampled_token_ids [i ])
1140
1140
1141
+ kv_connector_output = None if (
1142
+ finished_sending is None
1143
+ and finished_recving is None ) else KVConnectorOutput (
1144
+ finished_sending = finished_sending ,
1145
+ finished_recving = finished_recving ,
1146
+ )
1147
+
1141
1148
model_runner_output = ModelRunnerOutput (
1142
1149
req_ids = req_ids ,
1143
1150
req_id_to_index = self .input_batch .req_id_to_index ,
@@ -1146,10 +1153,8 @@ def concat_lists(input_lists):
1146
1153
logprobs = logprobs_lists ,
1147
1154
prompt_logprobs_dict = prompt_logprobs_dict ,
1148
1155
pooler_output = [],
1149
- kv_connector_output = KVConnectorOutput (
1150
- finished_sending = finished_sending ,
1151
- finished_recving = finished_recving ,
1152
- ))
1156
+ kv_connector_output = kv_connector_output ,
1157
+ )
1153
1158
1154
1159
# Check there are no new graphs compiled - all the graphs should be
1155
1160
# captured and compiled during warm up.
You can’t perform that action at this time.
0 commit comments