@@ -266,11 +266,36 @@ class SchedulerConfig:
         and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
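A quick sketch of how the reworked constructor behaves (assuming the class is importable as vllm.config.SchedulerConfig; the concrete values are made up for illustration):

    from vllm.config import SchedulerConfig

    # Passing None exercises the new default: max(max_model_len, 2048).
    config = SchedulerConfig(max_num_batched_tokens=None,
                             max_num_seqs=256,
                             max_model_len=4096)
    assert config.max_num_batched_tokens == 4096

    # An explicit budget smaller than max_model_len now fails fast in
    # _verify_args() instead of silently capping the usable sequence length.
    try:
        SchedulerConfig(max_num_batched_tokens=2048,
                        max_num_seqs=256,
                        max_model_len=4096)
    except ValueError as exc:
        print(exc)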
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        raise ValueError(
+            "The model's config.json must contain one of the following keys "
+            "to determine the original maximum length of the model: "
+            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if derived_max_model_len == float("inf"):
-            raise ValueError(
-                "When using rope_scaling, the model's config.json must "
-                "contain one of the following keys to determine the original "
-                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
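A rough, self-contained sketch of what this reordering means: the "config.json has no length key" error now fires regardless of rope_scaling, and the scaling factor is applied afterwards. The SimpleNamespace stand-in and the shortened possible_keys list below are illustrative only, not the function's real inputs:

    from types import SimpleNamespace

    # Stand-in for a Hugging Face config with a context-length key and
    # linear rope scaling.
    hf_config = SimpleNamespace(max_position_embeddings=4096,
                                rope_scaling={"type": "linear", "factor": 2.0})
    possible_keys = ["max_position_embeddings", "max_sequence_length"]

    derived_max_model_len = float("inf")
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)

    # With this change, a config exposing none of the keys is rejected even
    # when rope_scaling is absent.
    if derived_max_model_len == float("inf"):
        raise ValueError(f"config.json must contain one of {possible_keys}")

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        derived_max_model_len *= rope_scaling["factor"]

    assert derived_max_model_len == 8192  # 4096 * 2.0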
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
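A brief note on the int() cast: after the rope-scaling multiplication in the previous hunk, the derived length becomes a float, so the value returned here can be non-integral unless cast. A tiny standalone illustration (not the function's actual code path):

    max_model_len = 4096 * 2.0          # float after multiplying by rope_scaling["factor"]
    assert int(max_model_len) == 8192   # the cast restores an integer token count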