
Commit 52170da

maxdebayser authored and njhill committed
Add input validation for max_sequence_length and max_new_tokens
Adjust max_new_tokens to be less than max_sequence_length and warn the user instead of crashing in the warmup logic.
1 parent 36b8d86 commit 52170da

File tree (2 files changed: +28 -3 lines)

router/src/server.rs
server/text_generation_server/utils/warmup.py

router/src/server.rs

Lines changed: 23 additions & 3 deletions
@@ -289,6 +289,26 @@ async fn do_run<B: BatchType>(
         panic!("max_prefill_padding ({}) must be a percentage in the range [0.0, 1.0]", max_prefill_padding)
     }
 
+    if args.max_new_tokens < 1 {
+        panic!("max_new_tokens ({}) must be at least 1", args.max_new_tokens)
+    }
+
+    if args.max_sequence_length < 2 {
+        panic!("max_sequence_length ({}) must be at least 2 (1 input + 1 output)", args.max_sequence_length)
+    }
+
+    let max_new_tokens = if args.max_new_tokens < args.max_sequence_length {
+        args.max_new_tokens
+    } else {
+        tracing::warn!(
+            "adjusting max_new_tokens ({}) down to max_sequence_length - 1 ({})",
+            args.max_new_tokens,
+            args.max_sequence_length - 1
+        );
+        args.max_sequence_length - 1
+    };
+
+
     let tokenizers = AsyncTokenizer::new(
         &args.tokenizer, args.tokenization_workers
     );
@@ -318,14 +338,14 @@ async fn do_run<B: BatchType>(
         tokenizers.clone(),
         args.client,
         args.max_sequence_length,
-        args.max_new_tokens,
+        max_new_tokens,
     );
     let shared_state = ServerState {
         validation,
         batcher,
         limit_concurrent_requests: Arc::new(Semaphore::new(args.max_concurrent_requests)),
         max_sequence_length: args.max_sequence_length,
-        max_new_tokens: args.max_new_tokens,
+        max_new_tokens: max_new_tokens,
         seq2seq,
         default_include_stop_seqs: args.default_include_stop_seqs,
     };
@@ -353,7 +373,7 @@ async fn do_run<B: BatchType>(
     // Generated tokens buckets
     let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
     let max_new_tokens_buckets: Vec<f64> = (0..64)
-        .map(|x| (args.max_new_tokens as f64 / 64.0) * (x + 1) as f64)
+        .map(|x| (max_new_tokens as f64 / 64.0) * (x + 1) as f64)
         .collect();
     // Max new tokens buckets
     let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
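
For reference, a minimal standalone sketch of the validation and clamping behavior added above. The function name clamp_max_new_tokens and the usize parameter types are illustrative only; the real router reads these values from its parsed CLI args and logs through tracing::warn!, whereas the sketch uses eprintln! so it runs with no dependencies.

// Sketch only: mirrors the checks in do_run, not the actual router API.
fn clamp_max_new_tokens(max_new_tokens: usize, max_sequence_length: usize) -> usize {
    if max_new_tokens < 1 {
        panic!("max_new_tokens ({}) must be at least 1", max_new_tokens);
    }
    if max_sequence_length < 2 {
        panic!("max_sequence_length ({}) must be at least 2 (1 input + 1 output)", max_sequence_length);
    }
    if max_new_tokens < max_sequence_length {
        max_new_tokens
    } else {
        // Leave room for at least one input token instead of failing later in warmup.
        eprintln!(
            "adjusting max_new_tokens ({}) down to max_sequence_length - 1 ({})",
            max_new_tokens,
            max_sequence_length - 1
        );
        max_sequence_length - 1
    }
}

fn main() {
    assert_eq!(clamp_max_new_tokens(256, 2048), 256);   // already valid, returned unchanged
    assert_eq!(clamp_max_new_tokens(4096, 2048), 2047); // clamped, with a warning printed
}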

server/text_generation_server/utils/warmup.py

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,11 @@ def __eval_shape(batch_size: int, input_length: int, num_new_tokens: int):
 
     def __safe_eval_shape(batch_size: int, input_length: int, num_new_tokens: int):
         try:
+            if batch_size == 0 or input_length == 0 or num_new_tokens == 0:
+                # If input or output is 0, this means that max_input_len_for_nt or max_output_len_for_nt
+                # couldn't find a safe sequence length
+                print(f">> skipping __eval_shape({batch_size}, {input_length}, {num_new_tokens}) due to zero argument")
+                return
             __eval_shape(batch_size, input_length, num_new_tokens)
         except torch.cuda.OutOfMemoryError as e:
             print(">> caught OOM error: ", e)
