@@ -141,9 +141,11 @@ def __init__(
141141
142142 if rq_state is None :
143143 rq_state = {}
144- self .is_requeued = True
145- else :
146144 self .is_requeued = False
145+ self .rq_new_tokens = 0
146+ else :
147+ self .is_requeued = True
148+ self .rq_new_tokens = rq_state ["rq_new_tokens" ]
147149
148150 self .generator = None
149151 self .pagetable = None
@@ -189,7 +191,6 @@ def __init__(
189191 self .sequences .append (seq )
190192
191193 # Generation parameters
192- assert max_new_tokens >= 2
193194 self .max_new_tokens = max_new_tokens - 1 or 100
194195 self .min_new_tokens = min_new_tokens
195196 self .new_tokens = 0 if self .prefix_token is None else - 1
@@ -589,7 +590,7 @@ def emit(
589590 self .is_finished = True
590591 r .update ({
591592 "full_completion" : self .full_completion ,
592- "new_tokens" : self .new_tokens ,
593+ "new_tokens" : self .rq_new_tokens + self . new_tokens ,
593594 "prompt_tokens" : len (self .sequences [0 ].input_ids ),
594595 "time_enqueued" : self .time_enqueued ,
595596 "time_prefill" : self .time_prefill ,
@@ -796,6 +797,7 @@ def prepare_for_requeue(self):
796797 "time_enqueued" : self .time_enqueued ,
797798 "time_prefill" : self .time_prefill ,
798799 "time_generate" : self .time_generate ,
800+ "rq_new_tokens" : self .new_tokens - 1
799801 }
800802
801803 rq_job = Job (
0 commit comments