# NOTE(review): signature reconstructed from a truncated diff hunk header —
# parameters inferred from body usage (only `input_metadata` is read) and the
# return type from the final return statement. Confirm against the full file.
def _get_penalties(
    input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
    """Collect per-token presence and frequency penalties.

    Flattens the penalties so there is exactly one entry per sequence,
    in the same order the sequences appear in ``seq_groups``.
    """
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for seq_ids, sampling_params in input_metadata.seq_groups:
        # Replicate the group's penalties once per sequence in the group.
        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
    return presence_penalties, frequency_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    """Collect the generated token ids of every sequence.

    Returns one ``List[int]`` per sequence, ordered by the sequence's
    position within ``input_metadata.seq_groups``.
    """
    output_tokens: List[List[int]] = []
    for seq_ids, _ in input_metadata.seq_groups:
        for seq_id in seq_ids:
            # Look up the per-sequence data and take its output ids.
            output_tokens.append(
                input_metadata.seq_data[seq_id].output_token_ids)
    return output_tokens
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    """Collect the sampling temperature for every sequence.

    Temperatures below ``_SAMPLING_EPS`` are treated as zero temperature
    (greedy sampling / beam search) and replaced by 1.0 so the later
    logit division cannot divide by zero.
    """
    temperatures: List[float] = []
    for seq_ids, sampling_params in input_metadata.seq_groups:
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # Zero temperature means deterministic sampling; use 1.0 to
            # keep the division well-defined without changing the argmax.
            temperature = 1.0
        # One entry per sequence in the group.
        temperatures += [temperature] * len(seq_ids)
    return temperatures
# NOTE(review): parameter list reconstructed from a truncated diff hunk header —
# `input_metadata` and `vocab_size` are the only names the body reads.
def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    """Collect per-sequence top-p and top-k values, flattened in
    ``seq_groups`` order."""
    top_ps: List[float] = []
    top_ks: List[int] = []
    for seq_ids, sampling_params in input_metadata.seq_groups:
        top_p = sampling_params.top_p
        # k must not exceed the vocabulary size.
        top_k = min(sampling_params.top_k, vocab_size)
        if top_k == -1:
            # k=-1 means no truncation.
            top_k = vocab_size
        top_ps += [top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks
0 commit comments