@@ -523,6 +523,7 @@ def beam_search(
         params: BeamSearchParams,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         use_tqdm: bool = False,
+        concurrency_limit: Optional[int] = None,
     ) -> list[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -533,6 +534,8 @@ def beam_search(
             params: The beam search parameters.
             lora_request: LoRA request to use for generation, if any.
             use_tqdm: Whether to use tqdm to display the progress bar.
+            concurrency_limit: The maximum number of concurrent requests.
+                If None, the number of concurrent requests is unlimited.
         """
         # TODO: how does beam search work together with length penalty,
         # frequency, penalty, and stopping criteria, etc.?
@@ -551,6 +554,15 @@ def beam_search(
                 length_penalty,
             )

+        if use_tqdm and concurrency_limit is not None:
+            logger.warning(
+                "Progress bar is not supported when using concurrency_limit. "
+                "Disabling progress bar.")
+            use_tqdm = False
+
+        if concurrency_limit is None:
+            concurrency_limit = len(prompts)
+
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
             token_prompt_kwargs: TokensPrompt = {
@@ -595,73 +607,79 @@ def create_tokens_prompt_from_beam(
                     **mm_kwargs,
                 ), )

-        token_iter = range(max_tokens)
-        if use_tqdm:
-            token_iter = tqdm(token_iter,
-                              desc="Beam search",
-                              unit="token",
-                              unit_scale=False)
-            logger.warning(
-                "The progress bar shows the upper bound on token steps and "
-                "may finish early due to stopping conditions. It does not "
-                "reflect instance-level progress.")
-
-        for _ in token_iter:
-            all_beams: list[BeamSearchSequence] = list(
-                sum((instance.beams for instance in instances), []))
-            pos = [0] + list(
-                itertools.accumulate(
-                    len(instance.beams) for instance in instances))
-            instance_start_and_end: list[tuple[int, int]] = list(
-                zip(pos[:-1], pos[1:]))
-
-            if len(all_beams) == 0:
-                break
-
-            # create the corresponding batch entries for prompt & optional lora
-            prompts_batch, lora_req_batch = zip(
-                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
-                  for beam in all_beams])
-
-            # only runs for one step
-            # we don't need to use tqdm here
-            output = self.generate(prompts_batch,
-                                   sampling_params=beam_search_params,
-                                   use_tqdm=False,
-                                   lora_request=lora_req_batch)
-
-            for (start, end), instance in zip(instance_start_and_end,
-                                              instances):
-                instance_new_beams = []
-                for i in range(start, end):
-                    current_beam = all_beams[i]
-                    result = output[i]
-
-                    if result.outputs[0].logprobs is not None:
-                        # if `result.outputs[0].logprobs` is None, it means
-                        # the sequence is completed because of the max-model-len
-                        # or abortion. we don't need to add it to the new beams.
-                        logprobs = result.outputs[0].logprobs[0]
-                        for token_id, logprob_obj in logprobs.items():
-                            new_beam = BeamSearchSequence(
-                                tokens=current_beam.tokens + [token_id],
-                                logprobs=current_beam.logprobs + [logprobs],
-                                lora_request=current_beam.lora_request,
-                                cum_logprob=current_beam.cum_logprob +
-                                logprob_obj.logprob,
-                                multi_modal_data=current_beam.multi_modal_data,
-                                mm_processor_kwargs=current_beam.
-                                mm_processor_kwargs)
-
-                            if token_id == tokenizer.eos_token_id and \
-                                not ignore_eos:
-                                instance.completed.append(new_beam)
-                            else:
-                                instance_new_beams.append(new_beam)
-                sorted_beams = sorted(instance_new_beams,
-                                      key=sort_beams_key,
-                                      reverse=True)
-                instance.beams = sorted_beams[:beam_width]
+        for prompt_start in range(0, len(prompts), concurrency_limit):
+            instances_batch = instances[prompt_start:prompt_start +
+                                        concurrency_limit]
+
+            token_iter = range(max_tokens)
+            if use_tqdm:
+                token_iter = tqdm(token_iter,
+                                  desc="Beam search",
+                                  unit="token",
+                                  unit_scale=False)
+                logger.warning(
+                    "The progress bar shows the upper bound on token steps and "
+                    "may finish early due to stopping conditions. It does not "
+                    "reflect instance-level progress.")
+            for _ in token_iter:
+                all_beams: list[BeamSearchSequence] = list(
+                    sum((instance.beams for instance in instances_batch), []))
+                pos = [0] + list(
+                    itertools.accumulate(
+                        len(instance.beams) for instance in instances_batch))
+                instance_start_and_end: list[tuple[int, int]] = list(
+                    zip(pos[:-1], pos[1:]))
+
+                if len(all_beams) == 0:
+                    break
+
+                # create corresponding batch entries for prompt & optional lora
+                prompts_batch, lora_req_batch = zip(
+                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
+                      for beam in all_beams])
+
+                # only runs for one step
+                # we don't need to use tqdm here
+                output = self.generate(prompts_batch,
+                                       sampling_params=beam_search_params,
+                                       use_tqdm=False,
+                                       lora_request=lora_req_batch)
+
+                for (start, end), instance in zip(instance_start_and_end,
+                                                  instances_batch):
+                    instance_new_beams = []
+                    for i in range(start, end):
+                        current_beam = all_beams[i]
+                        result = output[i]
+
+                        if result.outputs[0].logprobs is not None:
+                            # if `result.outputs[0].logprobs` is None, it means
+                            # the sequence is completed because of the
+                            # max-model-len or abortion. we don't need to add
+                            # it to the new beams.
+                            logprobs = result.outputs[0].logprobs[0]
+                            for token_id, logprob_obj in logprobs.items():
+                                new_beam = BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    lora_request=current_beam.lora_request,
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs)
+
+                                if token_id == tokenizer.eos_token_id and \
+                                    not ignore_eos:
+                                    instance.completed.append(new_beam)
+                                else:
+                                    instance_new_beams.append(new_beam)
+                    sorted_beams = sorted(instance_new_beams,
+                                          key=sort_beams_key,
+                                          reverse=True)
+                    instance.beams = sorted_beams[:beam_width]

         outputs = []
         for instance in instances:
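
For context, a minimal usage sketch of the new parameter follows. The model name, the prompt contents, and reading `sequences[0].text` from the returned `BeamSearchOutput` are illustrative assumptions, not part of this change.

```python
# Hypothetical usage sketch of the new concurrency_limit parameter.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

# beam_search accepts text prompts as {"prompt": ...} dicts (TextPrompt).
prompts = [{"prompt": f"Story {i}:"} for i in range(64)]
params = BeamSearchParams(beam_width=4, max_tokens=32)

# Run beam search over at most 8 prompts' instances at a time instead of
# batching all 64 together. With concurrency_limit=None (the default),
# every prompt is processed concurrently, as before this change.
outputs = llm.beam_search(prompts, params, concurrency_limit=8)

for out in outputs:
    print(out.sequences[0].text)
```

When `concurrency_limit` is left as `None`, the method falls back to `len(prompts)`, so existing callers see no behavioral change; passing a limit with `use_tqdm=True` logs a warning and disables the progress bar.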