
Commit 91fce82

change the timing of sorting logits (#1309)
1 parent: ac5cf86


vllm/model_executor/layers/sampler.py (16 additions, 24 deletions)
@@ -102,30 +102,24 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = {t: [] for t in SamplingType}
+    last_token_indices = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        sampling_type = sampling_params.sampling_type
+        seq_ids, _ = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices[sampling_type].append(start_idx + prompt_len -
-                                                     1)
+            last_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices[sampling_type].extend(
-                range(start_idx, start_idx + num_seqs))
+            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
 
-    all_last_token_indices = []
-    for sampling_type in SamplingType:
-        all_last_token_indices.extend(last_token_indices[sampling_type])
-    all_last_token_indices = torch.tensor(all_last_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
-    return hidden_states.index_select(0, all_last_token_indices)
+    last_token_indices = torch.tensor(last_token_indices,
+                                      dtype=torch.long,
+                                      device=hidden_states.device)
+    return hidden_states.index_select(0, last_token_indices)
 
 
 def _get_penalties(
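The hunk above drops the per-SamplingType bucketing, so the gathered last-token rows now stay in the original batch order instead of being pre-sorted by sampling type. A minimal sketch of that gather on toy data, assuming two prompts of lengths 3 and 2 laid out back-to-back (names and shapes here are illustrative, not vLLM's API):

import torch

# Toy "hidden states": 5 token positions, hidden size 4.
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)
prompt_lens = [3, 2]  # hypothetical prompt lengths

# Collect the index of the last token of each prompt, in batch order.
last_token_indices = []
start_idx = 0
for prompt_len in prompt_lens:
    last_token_indices.append(start_idx + prompt_len - 1)
    start_idx += prompt_len

indices = torch.tensor(last_token_indices, dtype=torch.long)
pruned = hidden_states.index_select(0, indices)  # picks rows 2 and 4
print(pruned.shape)  # torch.Size([2, 4])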
@@ -424,27 +418,26 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
+    start_idx = 0
+    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        category_num_tokens[sampling_type] += num_seqs
-
+        categorized_seq_ids[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
     seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
-    category_start_idx = 0
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = category_num_tokens[sampling_type]
+        num_tokens = len(categorized_seq_ids[sampling_type])
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[category_start_idx:category_start_idx +
-                                     num_tokens]
-        category_probs = probs[category_start_idx:category_start_idx +
-                               num_tokens]
+        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
+        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
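Because the rows are no longer grouped contiguously by sampling type, a slice like logprobs[category_start_idx:category_start_idx + num_tokens] would pick the wrong rows; the hunk above therefore switches to integer-list (fancy) indexing with the per-category row ids. A hedged sketch of that selection on toy tensors (the category labels and shapes are made up for illustration):

import torch

probs = torch.rand(6, 10)  # 6 sampled tokens in batch order, vocab size 10
categorized_seq_ids = {"greedy": [0, 3, 5], "random": [1, 2, 4]}

for sampling_type, row_ids in categorized_seq_ids.items():
    # Gathers possibly non-contiguous rows in one shot, in row_ids order.
    category_probs = probs[row_ids]
    assert category_probs.shape == (len(row_ids), probs.size(1))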
@@ -497,6 +490,5 @@ def _sample(
             sample_idx += num_parent_seqs
             result_idx += num_results
         assert sample_idx == num_tokens
-        category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
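Since each category is now selected by explicit row ids rather than a running offset, the category_start_idx counter removed above is redundant: the id lists across sampling types partition the token rows, which the per-category assert sample_idx == num_tokens continues to check. A small sanity sketch of that invariant under the same toy setup as above:

categorized_seq_ids = {"greedy": [0, 3, 5], "random": [1, 2, 4]}
total_tokens = 6

# Every token row belongs to exactly one sampling-type category.
all_ids = sorted(i for ids in categorized_seq_ids.values() for i in ids)
assert all_ids == list(range(total_tokens))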
