
Commit f187877

[FIX] Simplify sampler logic (#1156)
1 parent 947b794 commit f187877

1 file changed

vllm/model_executor/layers/sampler.py

Lines changed: 10 additions & 37 deletions
@@ -133,37 +133,22 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            presence_penalties.append(p)
-            frequency_penalties.append(f)
-        else:
-            # A generation token.
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
+        presence_penalties += [p] * len(seq_ids)
+        frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, _ = seq_group
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            # NOTE: While the prompt input usually has no output tokens,
-            # it may have output tokens in the case of recomputation.
-            seq_id = seq_ids[0]
+        for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
             output_tokens.append(seq_data.output_token_ids)
-        else:
-            # A generation token.
-            for seq_id in seq_ids:
-                seq_data = input_metadata.seq_data[seq_id]
-                output_tokens.append(seq_data.output_token_ids)
     return output_tokens

@@ -221,21 +206,15 @@ def _apply_penalties(
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
             # NOTE: Zero temperature means deterministic sampling
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            temperatures.append(temperature)
-        else:
-            # A generation token.
-            temperatures += [temperature] * len(seq_ids)
+        temperatures += [temperature] * len(seq_ids)
     return temperatures

@@ -245,21 +224,15 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            top_ps.append(top_p)
-            top_ks.append(top_k)
-        else:
-            # A generation token.
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
+        top_ps += [top_p] * len(seq_ids)
+        top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
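
Each simplified helper drops the prompt-vs-generation branch and treats every sequence group uniformly; the parameter-collecting helpers (_get_penalties, _get_temperatures, _get_top_p_top_k) all converge on the same expansion pattern, repeating a group's sampling parameter once per sequence id. Below is a minimal standalone sketch of that pattern, assuming nothing beyond the (seq_ids, sampling_params) pairing visible in the diff; expand_per_sequence and the plain tuples are illustrative stand-ins for vLLM's InputMetadata.seq_groups, not the actual API.

from typing import List, Tuple

def expand_per_sequence(seq_groups: List[Tuple[List[int], float]]) -> List[float]:
    """Flatten one per-group value into one entry per sequence id."""
    values: List[float] = []
    for seq_ids, value in seq_groups:
        # Repeat the group's value once per sequence, regardless of whether
        # the group is in its prompt or generation phase.
        values += [value] * len(seq_ids)
    return values

# Example: two groups, the second sampling three sequences in parallel.
print(expand_per_sequence([([0], 0.8), ([1, 2, 3], 1.0)]))
# -> [0.8, 1.0, 1.0, 1.0]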
