@@ -65,8 +65,7 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
-            return self.output_token_ids[idx - self.num_prompt_tokens]
+        return self.output_token_ids[idx - self.num_prompt_tokens]


 class InputBatch:
@@ -261,30 +260,27 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
         Not applicable to pooling models.
         """

-        # Detailed added request metadata is only required for non-pooling
-        # models, to support logitsprocs
-        assert request.sampling_params
-
         # Fill the next empty index if there is one.
         if (new_req_index := self.batch_update_builder.pop_removed()) is None:
             # Append to end otherwise.
             new_req_index = self.num_reqs

         assert new_req_index < self.max_num_reqs
-        self.batch_update_builder.added.append(
-            (new_req_index, request.sampling_params, request.prompt_token_ids,
-             request.output_token_ids))
+        self.batch_update_builder.batch_changed = True
+        if request.sampling_params:
+            # Detailed added request metadata is only required for non-pooling
+            # models, to support logitsprocs.
+            self.batch_update_builder.added.append(
+                (new_req_index, request.sampling_params,
+                 request.prompt_token_ids, request.output_token_ids))
+
         return new_req_index

     def add_request(
         self,
         request: "CachedRequestState",
     ) -> int:
-        if not self.is_pooling_model:
-            # New request index bookkeeping for autoregressive models.
-            req_index = self._register_add_request(request)
-        else:
-            req_index = self.num_reqs
+        req_index = self._register_add_request(request)

         req_id = request.req_id
         if req_index == len(self._req_ids):
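This hunk unifies the add path: every request, pooling or autoregressive, now flows through `_register_add_request`, which flips `batch_changed` unconditionally but records the detailed `(index, sampling_params, prompt, output)` tuples needed by logitsprocs only when the request actually carries sampling params. A minimal sketch of that pattern; `TinyBatchUpdateBuilder` and `register_add` are illustrative names, not the real vLLM API:

```python
# Illustrative sketch only -- not the actual vLLM BatchUpdateBuilder.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class TinyBatchUpdateBuilder:
    removed: list[int] = field(default_factory=list)  # freed slot indices
    added: list[tuple] = field(default_factory=list)  # logitsprocs metadata
    moved: list[tuple] = field(default_factory=list)  # reorder events
    batch_changed: bool = False

    def pop_removed(self) -> Optional[int]:
        # Reuse a slot freed by an earlier removal, if any.
        return self.removed.pop() if self.removed else None

    def register_add(self, num_reqs: int, sampling_params: Any,
                     prompt_ids: list[int], output_ids: list[int]) -> int:
        new_index = self.pop_removed()
        if new_index is None:
            new_index = num_reqs  # no hole to fill: append at the end
        # Every add dirties the batch, even for pooling requests...
        self.batch_changed = True
        # ...but only sampling requests need per-request metadata.
        if sampling_params:
            self.added.append(
                (new_index, sampling_params, prompt_ids, output_ids))
        return new_index
```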
@@ -389,7 +385,7 @@ def add_request(
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids)
         else:
-            raise NotImplementedError(request)
+            raise NotImplementedError("Unrecognized request type")

         # Add request lora ID
         if request.lora_request:
@@ -419,13 +415,25 @@ def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        if not self.is_pooling_model:
-            # Autoregressive models require bookkeeping of removed requests to
-            # support logitsprocs.
-            self.batch_update_builder.removed_append(req_index)
+
+        self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None

+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index
+
         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
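Note the ordering in the rewritten `remove_request`: the LoRA reverse-mapping cleanup now runs before the pooling early-return, so pooling requests release their adapter bookkeeping too. A standalone sketch of that cleanup under the same semantics (written as a free function for illustration; in vLLM this state lives on `InputBatch`):

```python
import numpy as np


def release_lora_slot(req_id: str, req_index: int,
                      request_lora_mapping: np.ndarray,
                      lora_id_to_request_ids: dict[int, set[str]],
                      lora_id_to_lora_request: dict[int, object]) -> None:
    """Detach req_id from its LoRA adapter; drop the adapter once unused."""
    lora_id = int(request_lora_mapping[req_index])
    if lora_id == 0:
        return  # slot has no adapter attached (0 is the "no LoRA" id)
    req_ids = lora_id_to_request_ids[lora_id]
    req_ids.discard(req_id)
    if not req_ids:
        # Last request using this adapter: remove both reverse mappings.
        del lora_id_to_request_ids[lora_id]
        del lora_id_to_lora_request[lora_id]
    request_lora_mapping[req_index] = 0
```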
@@ -439,29 +447,14 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

-        # LoRA
-        lora_id = self.request_lora_mapping[req_index]
-        if lora_id != 0:
-            lora_req_ids = self.lora_id_to_request_ids[lora_id]
-            lora_req_ids.discard(req_id)
-            if not lora_req_ids:
-                del self.lora_id_to_request_ids[lora_id]
-                del self.lora_id_to_lora_request[lora_id]
-            self.request_lora_mapping[req_index] = 0
-
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         self.bad_words_token_ids.pop(req_index, None)
-        self.pooling_params.pop(req_id, None)
         return req_index

     def swap_states(self, i1: int, i2: int) -> None:
-        # For autoregressive models, track detailed request reordering info
-        # to support logitsprocs
-        self.batch_update_builder.moved.append(
-            (i1, i2, MoveDirectionality.SWAP))
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
         self._req_ids[i1], self._req_ids[i2] = \
@@ -479,18 +472,6 @@ def swap_states(self, i1: int, i2: int) -> None:
             self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
         self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = \
             self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
-        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
-            self.temperature_cpu[i2], self.temperature_cpu[i1]
-        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
-            self.top_p_cpu[i2], self.top_p_cpu[i1]
-        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
-            self.top_k_cpu[i2], self.top_k_cpu[i1]
-        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
-            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
-        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
-            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
-        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
-            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

         # NOTE: the following is unsafe
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
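The `# NOTE: the following is unsafe` comment in the surrounding context refers to numpy view aliasing: a tuple swap of two rows of the same array hands back views, not copies, so the second row is read only after the first assignment has already overwritten it. A small demonstration of why the explicit temporary is needed:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

# Unsafe: the right-hand side yields views into `a`, so writing row 0
# clobbers the data the second assignment then reads back, and both
# rows end up identical:
#   a[0, ...], a[1, ...] = a[1, ...], a[0, ...]

# Safe: materialize one row before overwriting anything.
tmp = a[0, ...].copy()
a[0, ...] = a[1, ...]
a[1, ...] = tmp
```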
@@ -501,18 +482,41 @@ def swap_states(self, i1: int, i2: int) -> None:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp

-        swap_dict_values(self.generators, i1, i2)
-        swap_dict_values(self.bad_words_token_ids, i1, i2)
+        self.block_table.swap_row(i1, i2)

-        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]

+        if self.is_pooling_model:
+            # Sampling and logits parameters don't apply to pooling models.
+            return
+
+        # For autoregressive models, track detailed request reordering info
+        # to support logitsprocs.
+        self.batch_update_builder.moved.append(
+            (i1, i2, MoveDirectionality.SWAP))
+
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
+
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[i1], \
                 self.allowed_token_ids_mask_cpu_tensor[i2] = \
                 self.allowed_token_ids_mask_cpu_tensor[i2], \
                 self.allowed_token_ids_mask_cpu_tensor[i1]
-        self.block_table.swap_row(i1, i2)

     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@ def condense(self) -> None:
         """
         num_reqs = self.num_reqs

-        if self.is_pooling_model:
-            # Will be contiguous in pooling case, just trim the lists.
-            del self._req_ids[num_reqs:]
-            del self.req_output_token_ids[num_reqs:]
-            return
-
         if not (empty_req_indices := self.batch_update_builder.removed):
             # All removed requests were replaced by added requests, or else no
             # requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@ def condense(self) -> None:
             # Move active request down into empty request
             # index.
             self.batch_update_builder.pop_removed()
-            # Autoregressive models require detailed tracking of condense
-            # operations to support logitsprocs
-            self.batch_update_builder.moved.append(
-                (last_req_index, empty_index,
-                 MoveDirectionality.UNIDIRECTIONAL))
             req_id = self._req_ids[last_req_index]
             output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
@@ -587,6 +580,21 @@ def condense(self) -> None:
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            if self.is_pooling_model:
+                last_req_index -= 1
+                # Sampling state not used by pooling models.
+                continue
+
+            # Autoregressive models require detailed tracking of condense
+            # operations to support logitsprocs
+            self.batch_update_builder.moved.append(
+                (last_req_index, empty_index,
+                 MoveDirectionality.UNIDIRECTIONAL))
+
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
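The condense loop is a two-pointer compaction: each hole left by a removed request is filled by the current highest-indexed live request, and with this change pooling requests take the same path but `continue` past the sampling-state copies. A minimal standalone sketch of the compaction idea over a plain list with `None` holes (hypothetical, for illustration only):

```python
def condense(slots: list, holes: list[int]) -> None:
    """Fill ascending hole indices from the tail so `slots` stays dense."""
    last = len(slots) - 1
    for hole in holes:
        # Skip trailing holes; everything past `last` is dead space.
        while last >= 0 and slots[last] is None:
            last -= 1
        if hole >= last:
            break  # already dense up to `last`
        slots[hole] = slots[last]  # move highest live entry into the hole
        slots[last] = None
        last -= 1
    del slots[last + 1:]  # trim the now-empty tail
```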
@@ -601,9 +609,6 @@ def condense(self) -> None:
             if generator is not None:
                 self.generators[empty_index] = generator

-            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
-                last_req_index]
-
             # TODO convert these to LogitsProcessors
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@ def refresh_metadata(self):
         """Apply any batch updates to sampling metadata."""

         if self.is_pooling_model:
-            # Batch changes every step for pooling models.
-            self.sampling_metadata = self._make_sampling_metadata()
+            batch_changed = self.batch_update_builder.reset()
+            if batch_changed:
+                self.sampling_metadata = self._make_sampling_metadata()
             return

         # For non-pooling models - generate and apply logitsprocs update;
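With this hunk, pooling models rebuild sampling metadata only when the batch actually changed, relying on `batch_update_builder.reset()` to drain pending updates and report whether anything happened. Continuing the `TinyBatchUpdateBuilder` sketch from earlier, an assumed `reset()` of that shape (the real method's semantics may differ):

```python
    # Method on TinyBatchUpdateBuilder, continuing the earlier sketch.
    def reset(self) -> bool:
        """Clear accumulated updates; report whether any occurred."""
        changed = (self.batch_changed or bool(self.added)
                   or bool(self.removed) or bool(self.moved))
        self.added.clear()
        self.removed.clear()
        self.moved.clear()
        self.batch_changed = False
        return changed
```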
@@ -720,19 +726,19 @@ def pooling_metadata(self) -> PoolingMetadata:
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        num_reqs = self.num_reqs
+        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
             (self.num_reqs, max_prompt_len),
             device="cpu",
             dtype=torch.int64,
             pin_memory=self.pin_memory,
         )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
+        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        for i in range(self.num_reqs):
+        for i in range(num_reqs):
             prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
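The rewritten `_make_prompt_token_ids_tensor` hoists `self.num_reqs` into a local and keeps the underlying pattern: allocate a pinned int64 CPU tensor, fill it through a zero-copy numpy view, pad each row past its true prompt length with `vocab_size` (a value no real token id can take), then start the device copy; `non_blocking=True` only overlaps with compute because the source buffer is pinned. A standalone sketch of the same construction (hypothetical free function; in vLLM this is a method on `InputBatch`):

```python
import numpy as np
import torch


def make_prompt_tensor(token_ids_cpu: np.ndarray,
                       num_prompt_tokens: np.ndarray,
                       num_reqs: int, vocab_size: int,
                       device: torch.device,
                       pin_memory: bool) -> torch.Tensor:
    max_prompt_len = int(num_prompt_tokens[:num_reqs].max())
    cpu_tensor = torch.empty((num_reqs, max_prompt_len),
                             device="cpu", dtype=torch.int64,
                             pin_memory=pin_memory)
    arr = cpu_tensor.numpy()  # zero-copy view into the pinned buffer
    arr[:] = token_ids_cpu[:num_reqs, :max_prompt_len]
    for i in range(num_reqs):
        # Pad past the prompt with vocab_size: no valid token has this id.
        arr[i, num_prompt_tokens[i]:] = vocab_size
    # Async host-to-device copy; only truly async from pinned memory.
    return cpu_tensor.to(device=device, non_blocking=True)
```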