Commit a19bc5c

Automatically configure max_num_batched_tokens (#1198)
1 parent 28e616c commit a19bc5c

File tree: 2 files changed (+35, -11 lines)

vllm/config.py

Lines changed: 34 additions & 9 deletions
@@ -266,11 +266,36 @@ class SchedulerConfig:
         and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        raise ValueError(
+            "The model's config.json must contain one of the following keys "
+            "to determine the original maximum length of the model: "
+            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if derived_max_model_len == float("inf"):
-            raise ValueError(
-                "When using rope_scaling, the model's config.json must "
-                "contain one of the following keys to determine the original "
-                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
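
A minimal sketch (not vLLM's actual code) of how the new default behaves when max_num_batched_tokens is left unset; the helper name resolve_max_num_batched_tokens is hypothetical and only mirrors the branch added to SchedulerConfig.__init__ above:

from typing import Optional

def resolve_max_num_batched_tokens(max_num_batched_tokens: Optional[int],
                                   max_model_len: int) -> int:
    # None means "auto-configure": follow the model's context length, but use
    # 2048 as a floor so short-context models still get batches large enough
    # for good throughput.
    if max_num_batched_tokens is not None:
        return max_num_batched_tokens
    return max(max_model_len, 2048)

print(resolve_max_num_batched_tokens(None, 512))     # 2048 (floor kicks in)
print(resolve_max_num_batched_tokens(None, 8192))    # 8192 (follows max_model_len)
print(resolve_max_num_batched_tokens(16384, 8192))   # 16384 (explicit value wins)

Note that the real SchedulerConfig also calls _verify_args(), so an explicit value smaller than max_model_len or max_num_seqs now raises a ValueError instead of silently limiting sequence length.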

vllm/engine/arg_utils.py

Lines changed: 1 addition & 2 deletions
@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -34,7 +34,6 @@ class EngineArgs:
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
     def add_cli_args(
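
A usage sketch from the caller's side, assuming the EngineArgs dataclass exactly as shown in the diff above (the model name is only an illustrative placeholder):

from vllm.engine.arg_utils import EngineArgs

# Leaving max_num_batched_tokens unset now means "auto": SchedulerConfig will
# later resolve it to max(max_model_len, 2048).
args = EngineArgs(model="facebook/opt-125m")

# Passing an explicit value still overrides the automatic default.
args_explicit = EngineArgs(model="facebook/opt-125m", max_num_batched_tokens=4096)

The old clamp self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) is gone; the relationship between the two limits is now checked in SchedulerConfig._verify_args, which raises instead of silently clamping.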
