@@ -102,30 +102,24 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = {t: [] for t in SamplingType}
+    last_token_indices = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        sampling_type = sampling_params.sampling_type
+        seq_ids, _ = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices[sampling_type].append(start_idx + prompt_len -
-                                                     1)
+            last_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices[sampling_type].extend(
-                range(start_idx, start_idx + num_seqs))
+            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
 
-    all_last_token_indices = []
-    for sampling_type in SamplingType:
-        all_last_token_indices.extend(last_token_indices[sampling_type])
-    all_last_token_indices = torch.tensor(all_last_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
-    return hidden_states.index_select(0, all_last_token_indices)
+    last_token_indices = torch.tensor(last_token_indices,
+                                      dtype=torch.long,
+                                      device=hidden_states.device)
+    return hidden_states.index_select(0, last_token_indices)
 
 
 def _get_penalties(
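The rewritten `_prune_hidden_states` no longer buckets last-token indices per `SamplingType`; it collects them in one flat list, in seq-group order, and gathers the corresponding rows with a single `index_select`. Below is a minimal sketch of that gather, not part of the diff, with toy shapes and values that are purely illustrative:

```python
import torch

# Toy hidden states: 6 token rows, hidden size 2 (illustrative values).
hidden_states = torch.arange(12, dtype=torch.float32).reshape(6, 2)
# Assume one prompt of length 3 (its last token is row 2) followed by a
# decode group of 3 sequences contributing one row each (rows 3..5).
last_token_indices = torch.tensor([2, 3, 4, 5], dtype=torch.long)
pruned = hidden_states.index_select(0, last_token_indices)
print(pruned.shape)  # torch.Size([4, 2])
```

Keeping the rows in seq-group order here is what lets `_sample` below look rows up by position instead of relying on a per-category contiguous layout.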
@@ -424,27 +418,26 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
+    start_idx = 0
+    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        category_num_tokens[sampling_type] += num_seqs
-
+        categorized_seq_ids[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
     seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
-    category_start_idx = 0
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = category_num_tokens[sampling_type]
+        num_tokens = len(categorized_seq_ids[sampling_type])
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[category_start_idx:category_start_idx +
-                                     num_tokens]
-        category_probs = probs[category_start_idx:category_start_idx +
-                               num_tokens]
+        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
+        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
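Since token rows now stay in seq-group order rather than being grouped by sampling type, `_sample` can no longer take a contiguous `category_start_idx` slice per category; it gathers each category's rows by their original positions. A minimal sketch of that list-based advanced indexing, not part of the diff, using illustrative string keys in place of the `SamplingType` enum members:

```python
import torch

# Four token rows in seq-group order, alternating categories
# (illustrative values; the real code keys the dict by SamplingType).
probs = torch.tensor([[0.9, 0.1],   # row 0: a greedy sequence
                      [0.4, 0.6],   # row 1: a random sequence
                      [0.8, 0.2],   # row 2: a greedy sequence
                      [0.3, 0.7]])  # row 3: a random sequence
categorized_seq_ids = {"greedy": [0, 2], "random": [1, 3]}
# Advanced indexing with a list gathers non-contiguous rows in one shot.
greedy_probs = probs[categorized_seq_ids["greedy"]]  # rows 0 and 2
print(greedy_probs)  # tensor([[0.9000, 0.1000], [0.8000, 0.2000]])
```

Unlike a contiguous slice, indexing with a list copies the selected rows, but it removes the need to keep tokens physically grouped by sampling type before sampling.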
@@ -497,6 +490,5 @@ def _sample(
             sample_idx += num_parent_seqs
             result_idx += num_results
         assert sample_idx == num_tokens
-        category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
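With the slice bookkeeping gone, the running `category_start_idx` offset has nothing left to track, so the hunk simply drops it. The per-category `assert sample_idx == num_tokens` still verifies that each category consumed exactly its own rows. And because `seq_outputs_dict` is keyed by the original seq-group index, the final list comprehension restores input order regardless of which category was processed first. A tiny illustration with made-up placeholder values, not part of the diff:

```python
# Outputs keyed by original seq-group index (hypothetical values).
seq_outputs_dict = {1: ["random-out"], 0: ["greedy-out"]}
ordered = [seq_outputs_dict[i] for i in range(len(seq_outputs_dict))]
print(ordered)  # [['greedy-out'], ['random-out']]
```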