
Commit c80c53a

[BugFix] Fix batch updates for pooling models (#23398)
Signed-off-by: Nick Hill <[email protected]>
1 parent 24d0c9e commit c80c53a

File tree

3 files changed, +95 -79 lines changed

vllm/v1/sample/logits_processor/state.py

Lines changed: 16 additions & 4 deletions
@@ -50,6 +50,10 @@ def __init__(
         self.added = added or []
         self._is_removed_sorted = False

+        # Used to track changes in the pooling case
+        # where we don't populate the added list.
+        self.batch_changed = False
+
     def _ensure_removed_sorted(self) -> None:
         """Sort removed request indices in
         descending order.
@@ -80,6 +84,7 @@ def removed_append(self, index: int) -> None:
             raise RuntimeError("Cannot register new removed request after"
                                " self.removed has been read.")
         self._removed.append(index)
+        self.batch_changed = True

     def has_removed(self) -> bool:
         return bool(self._removed)
@@ -98,9 +103,15 @@ def pop_removed(self) -> Optional[int]:
             return self._removed.pop()
         return None

-    def _is_update(self) -> bool:
-        """True if there is a batch state change"""
-        return any((self._removed, self.moved, self.added))
+    def reset(self) -> bool:
+        """Returns True if there were any changes to the batch."""
+        self._is_removed_sorted = False
+        self._removed.clear()
+        self.moved.clear()
+        self.added.clear()
+        batch_changed = self.batch_changed
+        self.batch_changed = False
+        return batch_changed

     def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
         """Generate a logitsprocs batch update data structure and reset
@@ -114,7 +125,8 @@ def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
         """
         # Reset removal-sorting logic
         self._is_removed_sorted = False
-        if not self._is_update():
+        self.batch_changed = False
+        if not any((self._removed, self.moved, self.added)):
             # No update; short-circuit
             return None
         # Build batch state update
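For context, here is a standalone sketch of the change-tracking pattern this file introduces (simplified names and types; MiniBatchUpdateBuilder is a toy stand-in, not the vLLM class): the builder still records removals/moves/adds for logits processors, and the cheap batch_changed flag covers the pooling case, where the added list is never populated.

class MiniBatchUpdateBuilder:
    """Toy stand-in for BatchUpdateBuilder; names mirror the diff above."""

    def __init__(self) -> None:
        self._removed: list = []
        self.moved: list = []
        self.added: list = []
        # Tracks changes in the pooling case, where `added` stays empty.
        self.batch_changed = False

    def removed_append(self, index: int) -> None:
        self._removed.append(index)
        self.batch_changed = True

    def reset(self) -> bool:
        """Clear all per-step state; return True if the batch changed."""
        self._removed.clear()
        self.moved.clear()
        self.added.clear()
        batch_changed = self.batch_changed
        self.batch_changed = False
        return batch_changed


builder = MiniBatchUpdateBuilder()
builder.removed_append(3)
assert builder.reset() is True   # a removal happened this step
assert builder.reset() is False  # nothing happened since the last reset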

vllm/v1/worker/gpu_input_batch.py

Lines changed: 76 additions & 70 deletions
@@ -65,8 +65,7 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
-            return self.output_token_ids[idx - self.num_prompt_tokens]
+        return self.output_token_ids[idx - self.num_prompt_tokens]


 class InputBatch:
@@ -261,30 +260,27 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
         Not applicable to pooling models.
         """

-        # Detailed added request metadata is only required for non-pooling
-        # models, to support logitsprocs
-        assert request.sampling_params
-
         # Fill the next empty index if there is one.
         if (new_req_index := self.batch_update_builder.pop_removed()) is None:
             # Append to end otherwise.
             new_req_index = self.num_reqs

         assert new_req_index < self.max_num_reqs
-        self.batch_update_builder.added.append(
-            (new_req_index, request.sampling_params, request.prompt_token_ids,
-             request.output_token_ids))
+        self.batch_update_builder.batch_changed = True
+        if request.sampling_params:
+            # Detailed added request metadata is only required for non-pooling
+            # models, to support logitsprocs.
+            self.batch_update_builder.added.append(
+                (new_req_index, request.sampling_params,
+                 request.prompt_token_ids, request.output_token_ids))
+
         return new_req_index

     def add_request(
         self,
         request: "CachedRequestState",
     ) -> int:
-        if not self.is_pooling_model:
-            # New request index bookkeeping for autoregressive models.
-            req_index = self._register_add_request(request)
-        else:
-            req_index = self.num_reqs
+        req_index = self._register_add_request(request)

         req_id = request.req_id
         if req_index == len(self._req_ids):
@@ -389,7 +385,7 @@ def add_request(
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids)
         else:
-            raise NotImplementedError(request)
+            raise NotImplementedError("Unrecognized request type")

         # Add request lora ID
         if request.lora_request:
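A hedged illustration of the add path after this change, using stand-in types (TinyBuilder and register_add_request below are hypothetical, not vLLM APIs): slot allocation and the batch_changed flag are now shared by all models, while the detailed added tuple, needed only by logits processors, is recorded just when sampling_params is present.

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class TinyBuilder:
    removed: list = field(default_factory=list)
    added: list = field(default_factory=list)
    batch_changed: bool = False


def register_add_request(builder: TinyBuilder, num_reqs: int,
                         sampling_params: Optional[Any],
                         prompt_token_ids: list,
                         output_token_ids: list) -> int:
    # Reuse a freed slot if one exists, otherwise append at the end.
    new_req_index = builder.removed.pop() if builder.removed else num_reqs
    builder.batch_changed = True
    if sampling_params is not None:
        # Detailed metadata only matters for logits processors.
        builder.added.append((new_req_index, sampling_params,
                              prompt_token_ids, output_token_ids))
    return new_req_index


builder = TinyBuilder()
idx = register_add_request(builder, num_reqs=0, sampling_params=None,
                           prompt_token_ids=[1, 2], output_token_ids=[])
assert idx == 0 and builder.batch_changed and not builder.added  # pooling case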
@@ -419,13 +415,25 @@ def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        if not self.is_pooling_model:
-            # Autoregressive models require bookkeeping of removed requests to
-            # support logitsprocs.
-            self.batch_update_builder.removed_append(req_index)
+
+        self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None

+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index
+
         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
@@ -439,29 +447,14 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

-        # LoRA
-        lora_id = self.request_lora_mapping[req_index]
-        if lora_id != 0:
-            lora_req_ids = self.lora_id_to_request_ids[lora_id]
-            lora_req_ids.discard(req_id)
-            if not lora_req_ids:
-                del self.lora_id_to_request_ids[lora_id]
-                del self.lora_id_to_lora_request[lora_id]
-            self.request_lora_mapping[req_index] = 0
-
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         self.bad_words_token_ids.pop(req_index, None)
-        self.pooling_params.pop(req_id, None)
         return req_index

     def swap_states(self, i1: int, i2: int) -> None:
-        # For autoregressive models, track detailed request reordering info
-        # to support logitsprocs
-        self.batch_update_builder.moved.append(
-            (i1, i2, MoveDirectionality.SWAP))
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
         self._req_ids[i1], self._req_ids[i2] =\
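The remove_request() restructuring above reads as: do the shared bookkeeping (removal tracking, LoRA cleanup) for every model type, then return early for pooling models before any sampling-only state is touched. A minimal toy version, with a made-up TinyBatch class rather than the real InputBatch:

class TinyBatch:
    """Made-up stand-in for InputBatch, just to show the control flow."""

    def __init__(self, is_pooling_model: bool):
        self.is_pooling_model = is_pooling_model
        self.req_id_to_index = {"req-0": 0}
        self.removed: list = []
        self.batch_changed = False
        self.pooling_params = {"req-0": object()}
        self.greedy_reqs = {"req-0"}

    def remove_request(self, req_id: str):
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        # Shared bookkeeping for both model types: the removal is always
        # recorded, which is what sets batch_changed.
        self.removed.append(req_index)
        self.batch_changed = True
        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index  # sampling-only cleanup below is skipped
        self.greedy_reqs.discard(req_id)
        return req_index


batch = TinyBatch(is_pooling_model=True)
assert batch.remove_request("req-0") == 0
assert batch.batch_changed and not batch.pooling_params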
@@ -479,18 +472,6 @@ def swap_states(self, i1: int, i2: int) -> None:
             self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
         self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
             self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
-        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
-            self.temperature_cpu[i2], self.temperature_cpu[i1]
-        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
-            self.top_p_cpu[i2], self.top_p_cpu[i1]
-        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
-            self.top_k_cpu[i2], self.top_k_cpu[i1]
-        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
-            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
-        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
-            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
-        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
-            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

         # NOTE: the following is unsafe
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@ def swap_states(self, i1: int, i2: int) -> None:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp

-        swap_dict_values(self.generators, i1, i2)
-        swap_dict_values(self.bad_words_token_ids, i1, i2)
+        self.block_table.swap_row(i1, i2)

-        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]

+        if self.is_pooling_model:
+            # Sampling and logits parameters don't apply to pooling models.
+            return
+
+        # For autoregressive models, track detailed request reordering info
+        # to support logitsprocs.
+        self.batch_update_builder.moved.append(
+            (i1, i2, MoveDirectionality.SWAP))
+
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
+
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[i1], \
                 self.allowed_token_ids_mask_cpu_tensor[i2] =\
                 self.allowed_token_ids_mask_cpu_tensor[i2], \
                 self.allowed_token_ids_mask_cpu_tensor[i1]
-        self.block_table.swap_row(i1, i2)

     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@ def condense(self) -> None:
         """
         num_reqs = self.num_reqs

-        if self.is_pooling_model:
-            # Will be contiguous in pooling case, just trim the lists.
-            del self._req_ids[num_reqs:]
-            del self.req_output_token_ids[num_reqs:]
-            return
-
         if not (empty_req_indices := self.batch_update_builder.removed):
             # All removed requests were replaced by added requests, or else no
             # requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@ def condense(self) -> None:
             # Move active request down into empty request
             # index.
             self.batch_update_builder.pop_removed()
-            # Autoregressive models require detailed tracking of condense
-            # operations to support logitsprocs
-            self.batch_update_builder.moved.append(
-                (last_req_index, empty_index,
-                 MoveDirectionality.UNIDIRECTIONAL))
             req_id = self._req_ids[last_req_index]
             output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
@@ -587,6 +580,21 @@ def condense(self) -> None:
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            if self.is_pooling_model:
+                last_req_index -= 1
+                # Sampling state not used by pooling models.
+                continue
+
+            # Autoregressive models require detailed tracking of condense
+            # operations to support logitsprocs
+            self.batch_update_builder.moved.append(
+                (last_req_index, empty_index,
+                 MoveDirectionality.UNIDIRECTIONAL))
+
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
@@ -601,9 +609,6 @@ def condense(self) -> None:
             if generator is not None:
                 self.generators[empty_index] = generator

-            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
-                last_req_index]
-
             # TODO convert these to LogitsProcessors
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[
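The swap_states() and condense() changes above follow the same shape: state shared by all model types (token ids, block table rows, LoRA mapping) is moved first, and pooling models skip only the logitsprocs bookkeeping and sampling-state moves. A toy version of the swap case, with plain lists standing in for the CPU-side arrays (this is an illustration, not the vLLM method):

def swap_states(i1: int, i2: int, token_rows: list, lora: list,
                temperature: list, moved: list,
                is_pooling_model: bool) -> None:
    # State shared by all model types is swapped unconditionally.
    token_rows[i1], token_rows[i2] = token_rows[i2], token_rows[i1]
    lora[i1], lora[i2] = lora[i2], lora[i1]
    if is_pooling_model:
        return                      # no sampling state to move
    moved.append((i1, i2, "SWAP"))  # reordering info for logits processors
    temperature[i1], temperature[i2] = temperature[i2], temperature[i1]


rows = [[1, 2], [3, 4]]
lora_ids = [0, 7]
temps = [1.0, 0.2]
moves: list = []
swap_states(0, 1, rows, lora_ids, temps, moves, is_pooling_model=True)
assert rows == [[3, 4], [1, 2]] and lora_ids == [7, 0]
assert temps == [1.0, 0.2] and not moves  # sampling state untouched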
@@ -626,8 +631,9 @@ def refresh_metadata(self):
         """Apply any batch updates to sampling metadata."""

         if self.is_pooling_model:
-            # Batch changes every step for pooling models.
-            self.sampling_metadata = self._make_sampling_metadata()
+            batch_changed = self.batch_update_builder.reset()
+            if batch_changed:
+                self.sampling_metadata = self._make_sampling_metadata()
             return

         # For non-pooling models - generate and apply logitsprocs update;
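With this refresh_metadata() change, pooling models no longer rebuild metadata unconditionally every step; they rebuild only when the builder reports a change. A self-contained toy of that control flow (the _Builder class and make_metadata callable are stand-ins, not the vLLM API):

class _Builder:
    def __init__(self):
        self.batch_changed = False

    def reset(self) -> bool:
        changed, self.batch_changed = self.batch_changed, False
        return changed


def refresh_pooling_metadata(builder, current_metadata, make_metadata):
    # Rebuild only when something was added/removed this step.
    return make_metadata() if builder.reset() else current_metadata


b = _Builder()
meta = {"version": 0}
b.batch_changed = True  # a request was added or removed this step
meta = refresh_pooling_metadata(b, meta, lambda: {"version": 1})
assert meta == {"version": 1}
meta = refresh_pooling_metadata(b, meta, lambda: {"version": 2})
assert meta == {"version": 1}  # unchanged batch -> no rebuild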
@@ -720,19 +726,19 @@ def pooling_metadata(self) -> PoolingMetadata:
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        num_reqs = self.num_reqs
+        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
             (self.num_reqs, max_prompt_len),
             device="cpu",
             dtype=torch.int64,
             pin_memory=self.pin_memory,
         )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
+        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        for i in range(self.num_reqs):
+        for i in range(num_reqs):
             prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
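The padding scheme used by _make_prompt_token_ids_tensor() can be shown in isolation: prompts shorter than the batch maximum are right-padded with vocab_size, a value no real token id can take. A small example with made-up shapes and values:

import torch

vocab_size = 100
num_prompt_tokens = [3, 1]                # per-request prompt lengths
token_ids = torch.tensor([[5, 6, 7, 0],
                          [9, 0, 0, 0]])  # (num_reqs, padded width)

num_reqs = len(num_prompt_tokens)
max_prompt_len = max(num_prompt_tokens)
prompt_token_ids = token_ids[:num_reqs, :max_prompt_len].clone()
for i in range(num_reqs):
    # Everything past the real prompt is filled with the out-of-range pad.
    prompt_token_ids[i, num_prompt_tokens[i]:] = vocab_size

print(prompt_token_ids)
# tensor([[  5,   6,   7],
#         [  9, 100, 100]])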

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 5 deletions
@@ -1489,10 +1489,8 @@ def _pool(
         for raw_output, seq_len, prompt_len in zip(
                 raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):

-            if seq_len == prompt_len:
-                pooler_output.append(raw_output.data)
-            else:
-                pooler_output.append(None)
+            output = raw_output.data if seq_len == prompt_len else None
+            pooler_output.append(output)

         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1522,7 +1520,7 @@ def execute_model(
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = (self._prepare_inputs(scheduler_output))
+         max_query_len) = self._prepare_inputs(scheduler_output)

         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
