@@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
81
81
82
82
def _process_sequence_group_outputs (self , seq_group : SequenceGroup ,
83
83
outputs : SequenceGroupOutput ) -> None :
84
+ sampling_params = seq_group .sampling_params
85
+ if sampling_params .n == 1 and not sampling_params .use_beam_search :
86
+ # only have one output sample
87
+ sample = outputs .samples [0 ]
88
+ # only have one sequence
89
+ seq = seq_group .seqs [0 ]
90
+ seq .append_token_id (sample .output_token , sample .logprobs )
91
+ if sampling_params .detokenize and self .detokenizer :
92
+ new_char_count = self .detokenizer .decode_sequence_inplace (
93
+ seq , sampling_params )
94
+ else :
95
+ new_char_count = 0
96
+ self .stop_checker .maybe_stop_sequence (
97
+ seq ,
98
+ new_char_count ,
99
+ sampling_params ,
100
+ lora_req = seq_group .lora_request ,
101
+ )
102
+ if seq .is_finished ():
103
+ for scheduler in self .scheduler :
104
+ scheduler .free_seq (seq )
105
+ return
106
+
84
107
# Process samples
85
108
samples = outputs .samples
86
109
parent_seqs = seq_group .get_seqs (status = SequenceStatus .RUNNING )
@@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
127
150
child_seqs .append ((parent , parent ))
128
151
129
152
for seq , _ in child_seqs :
130
- if seq_group . sampling_params .detokenize and self .detokenizer :
153
+ if sampling_params .detokenize and self .detokenizer :
131
154
new_char_count = self .detokenizer .decode_sequence_inplace (
132
- seq , seq_group . sampling_params )
155
+ seq , sampling_params )
133
156
else :
134
157
new_char_count = 0
135
158
self .stop_checker .maybe_stop_sequence (
136
159
seq ,
137
160
new_char_count ,
138
- seq_group . sampling_params ,
161
+ sampling_params ,
139
162
lora_req = seq_group .lora_request ,
140
163
)
141
164
142
165
# Non-beam search case
143
- if not seq_group . sampling_params .use_beam_search :
166
+ if not sampling_params .use_beam_search :
144
167
# For newly created child sequences, add them to the sequence group
145
168
# and fork them in block manager if they are not finished.
146
169
for seq , parent in child_seqs :
@@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
164
187
# Select the child sequences to keep in the sequence group.
165
188
selected_child_seqs : List [Tuple [Sequence , Optional [Sequence ]]] = []
166
189
unselected_child_seqs : List [Tuple [Sequence , Optional [Sequence ]]] = []
167
- beam_width = seq_group . sampling_params .best_of
168
- length_penalty = seq_group . sampling_params .length_penalty
190
+ beam_width = sampling_params .best_of
191
+ length_penalty = sampling_params .length_penalty
169
192
170
193
# Select the newly finished sequences with the highest scores
171
194
# to replace existing finished sequences.
@@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
219
242
best_running_seq = running_child_seqs [0 ][0 ]
220
243
current_worst_seq = all_finished_seqs [beam_width - 1 ][0 ]
221
244
stop_beam_search = self ._check_beam_search_early_stopping (
222
- seq_group . sampling_params .early_stopping ,
223
- seq_group . sampling_params , best_running_seq , current_worst_seq )
245
+ sampling_params .early_stopping , sampling_params ,
246
+ best_running_seq , current_worst_seq )
224
247
225
248
if stop_beam_search :
226
249
# Stop the beam search and remove all the running sequences from
0 commit comments