@@ -1191,6 +1191,7 @@ def _get_prompt_logprobs_dict(
         if not num_prompt_logprobs_dict:
             return {}
 
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
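For orientation, here is a minimal sketch of the per-request CPU buffer that `in_progress_prompt_logprobs_cpu` holds between engine steps. The field names and the `empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)` shape follow how the later hunks use the object; the dtypes, the pinning, and the class body itself are assumptions for illustration, not the actual vLLM definition of `LogprobsTensors`.

```python
# Illustrative stand-in (NOT the real vLLM LogprobsTensors) for the buffer
# stored per request id in input_batch.in_progress_prompt_logprobs_cpu.
from typing import NamedTuple

import torch


class LogprobsTensors(NamedTuple):
    # One row per prompt position that gets logprobs (num_prompt_tokens - 1),
    # one column per returned logprob plus the prompt token itself
    # (num_prompt_logprobs + 1).
    logprob_token_ids: torch.Tensor     # int64, [positions, cols]
    logprobs: torch.Tensor              # float32, [positions, cols]
    selected_token_ranks: torch.Tensor  # int64, [positions]

    @classmethod
    def empty_cpu(cls, num_positions: int, num_cols: int) -> "LogprobsTensors":
        # Assumption: pinned host memory so the later copy_(non_blocking=True)
        # GPU->CPU transfers can run asynchronously.
        pin = torch.cuda.is_available()
        return cls(
            torch.empty(num_positions, num_cols, dtype=torch.int64,
                        pin_memory=pin),
            torch.empty(num_positions, num_cols, dtype=torch.float32,
                        pin_memory=pin),
            torch.empty(num_positions, dtype=torch.int64, pin_memory=pin),
        )


# Keyed by request id; entries survive across steps until prefill completes.
in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
```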
@@ -1206,16 +1207,36 @@ def _get_prompt_logprobs_dict(
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
 
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
 
             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
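A small worked trace of the chunk bookkeeping in this hunk, under the assumption of a 9-token prompt prefilled 4 tokens per step. `plan_step` is a stand-alone stand-in that mirrors the branch structure above; it is not a vLLM function.

```python
# Toy trace of start_idx / num_logits across steps for one request.
num_prompt_tokens = 9


def plan_step(num_computed_tokens: int, num_tokens: int):
    start_idx = num_computed_tokens                 # first buffer row to fill
    start_tok = start_idx + 1                       # logprobs start at token 1
    num_remaining_tokens = num_prompt_tokens - start_tok
    if num_tokens <= num_remaining_tokens:
        # Still a chunk. In the == case all rows get filled this step, but
        # returning them is deferred until new generated tokens exist.
        num_logits, is_last = num_tokens, False
    else:
        # Last chunk; num_logits can be 0 if the prior step already covered
        # num_prompt_tokens - 1 positions.
        num_logits, is_last = num_remaining_tokens, True
    return start_idx, num_logits, is_last


for computed, scheduled in [(0, 4), (4, 4), (8, 1)]:
    print(plan_step(computed, scheduled))
# (0, 4, False) -> fills rows 0..3
# (4, 4, False) -> == case: fills rows 4..7 (all 8 rows), return deferred
# (8, 0, True)  -> nothing left to copy; buffered logprobs are now returned
```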
@@ -1236,19 +1257,23 @@ def _get_prompt_logprobs_dict(
                 logprobs, num_prompt_logprobs, tgt_token_ids)
 
             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
 
         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
 
         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()
 
         return prompt_logprobs_dict
 
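The slice-wise `copy_` calls above follow the usual CUDA async-transfer pattern: device tensors are copied into slices of a persistent host buffer with `non_blocking=True`, and a single `torch.cuda.synchronize()` at the end covers every outstanding copy, skipped when no request produced prompt logprobs. A self-contained sketch of that pattern (shapes and the CPU fallback are illustrative, not taken from vLLM):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Persistent host buffer: (num_prompt_tokens - 1) rows,
# (num_prompt_logprobs + 1) columns.
num_positions, num_cols = 8, 6
cpu_token_ids = torch.empty(num_positions, num_cols, dtype=torch.int64,
                            pin_memory=torch.cuda.is_available())

filled_any = False
for start_idx, num_logits in [(0, 4), (4, 4)]:
    # Stand-in for the token ids computed on the GPU for this chunk.
    gpu_chunk = torch.randint(0, 32_000, (num_logits, num_cols), device=device)
    chunk_slice = slice(start_idx, start_idx + num_logits)
    # Async copy into the matching rows of the host buffer; non_blocking
    # only has an effect for CUDA -> pinned-CPU copies.
    cpu_token_ids[chunk_slice].copy_(gpu_chunk, non_blocking=True)
    filled_any = True

# Mirror the guard in the diff: synchronize only if transfers were issued.
if filled_any and torch.cuda.is_available():
    torch.cuda.synchronize()
```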