@@ -289,6 +289,26 @@ async fn do_run<B: BatchType>(
289
289
panic ! ( "max_prefill_padding ({}) must be a percentage in the range [0.0, 1.0]" , max_prefill_padding)
290
290
}
291
291
292
+ if args. max_new_tokens < 1 {
293
+ panic ! ( "max_new_tokens ({}) at least 1" , args. max_new_tokens)
294
+ }
295
+
296
+ if args. max_sequence_length < 2 {
297
+ panic ! ( "max_sequence_length ({}) must be at least 2 (1 input + 1 output)" , args. max_sequence_length)
298
+ }
299
+
300
+ let max_new_tokens = if args. max_new_tokens < args. max_sequence_length {
301
+ args. max_new_tokens
302
+ } else {
303
+ tracing:: warn!(
304
+ "adjusting max_new_tokens ({}) down to max_sequence_length - 1 ({})" ,
305
+ args. max_new_tokens,
306
+ args. max_sequence_length-1
307
+ ) ;
308
+ args. max_sequence_length - 1
309
+ } ;
310
+
311
+
292
312
let tokenizers = AsyncTokenizer :: new (
293
313
& args. tokenizer , args. tokenization_workers
294
314
) ;
@@ -318,14 +338,14 @@ async fn do_run<B: BatchType>(
318
338
tokenizers. clone ( ) ,
319
339
args. client ,
320
340
args. max_sequence_length ,
321
- args . max_new_tokens ,
341
+ max_new_tokens,
322
342
) ;
323
343
let shared_state = ServerState {
324
344
validation,
325
345
batcher,
326
346
limit_concurrent_requests : Arc :: new ( Semaphore :: new ( args. max_concurrent_requests ) ) ,
327
347
max_sequence_length : args. max_sequence_length ,
328
- max_new_tokens : args . max_new_tokens ,
348
+ max_new_tokens : max_new_tokens,
329
349
seq2seq,
330
350
default_include_stop_seqs : args. default_include_stop_seqs ,
331
351
} ;
@@ -353,7 +373,7 @@ async fn do_run<B: BatchType>(
353
373
// Generated tokens buckets
354
374
let generated_tokens_matcher = Matcher :: Full ( String :: from ( "tgi_request_generated_tokens" ) ) ;
355
375
let max_new_tokens_buckets: Vec < f64 > = ( 0 ..64 )
356
- . map ( |x| ( args . max_new_tokens as f64 / 64.0 ) * ( x + 1 ) as f64 )
376
+ . map ( |x| ( max_new_tokens as f64 / 64.0 ) * ( x + 1 ) as f64 )
357
377
. collect ( ) ;
358
378
// Max new tokens buckets
359
379
let max_new_tokens_matcher = Matcher :: Full ( String :: from ( "tgi_request_max_new_tokens" ) ) ;
0 commit comments