
Commit e88db68

[Platform] platform agnostic for EngineArgs initialization (#11225)
Signed-off-by: wangxiyuan <[email protected]>
1 parent 59c9b6e commit e88db68

File tree

9 files changed: +37 -6 lines changed

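Taken together, the hunks below implement one hand-off: EngineArgs stops hard-coding a platform-specific block_size default and leaves it as None, and each Platform subclass fills in its own default inside check_and_update_config. A minimal, runnable sketch of that pattern follows; the CacheConfig, ModelConfig, VllmConfig, and platform classes here are simplified stand-ins mirroring the diff, not vLLM's actual implementations.

from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfig:
    # Mirrors the new EngineArgs behavior: unset until a platform resolves it.
    block_size: Optional[int] = None


@dataclass
class ModelConfig:
    max_model_len: int = 2048


@dataclass
class VllmConfig:
    cache_config: CacheConfig
    model_config: ModelConfig


class CudaPlatform:
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        # Same guard as the CUDA/ROCm/CPU/TPU/XPU/OpenVINO hunks below:
        # only fill in the default when the user left block_size unset.
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16


class HpuPlatform:
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        # Gaudi default of 128, as in the hpu.py hunk.
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128


class NeuronPlatform:
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        # Neuron needs block_size = max_model_len (see the neuron.py hunk);
        # note it assigns unconditionally rather than checking for None.
        if vllm_config.cache_config:
            vllm_config.cache_config.block_size = (
                vllm_config.model_config.max_model_len)


config = VllmConfig(CacheConfig(), ModelConfig(max_model_len=2048))
HpuPlatform.check_and_update_config(config)
assert config.cache_config.block_size == 128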

vllm/engine/arg_utils.py

Lines changed: 2 additions & 6 deletions
@@ -112,9 +112,7 @@ class EngineArgs:
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
-    # NOTE(kzawora): default block size for Gaudi should be 128
-    # smaller sizes still work, but very inefficiently
-    block_size: int = 16 if not current_platform.is_hpu() else 128
+    block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True

@@ -1036,9 +1034,7 @@ def create_engine_config(self,
             self.enable_prefix_caching = False

         cache_config = CacheConfig(
-            # neuron needs block_size = max_model_len
-            block_size=self.block_size if self.device != "neuron" else
-            (self.max_model_len if self.max_model_len is not None else 0),
+            block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,

vllm/platforms/cpu.py

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         cache_config = vllm_config.cache_config

+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

         if kv_cache_space >= 0:

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -137,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/hpu.py

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

+        # NOTE(kzawora): default block size for Gaudi should be 128
+        # smaller sizes still work, but very inefficiently
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 128
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")

vllm/platforms/neuron.py

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "vllm.worker.neuron_worker.NeuronWorker"

+        cache_config = vllm_config.cache_config
+        if cache_config:
+            # neuron needs block_size = max_model_len
+            vllm_config.cache_config.block_size = \
+                vllm_config.model_config.max_model_len
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")

vllm/platforms/openvino.py

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update cache config
         ov_core = ov.Core()
         cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
             if not OpenVinoPlatform.is_openvino_cpu():
                 logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"

vllm/platforms/rocm.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:

     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":

vllm/platforms/tpu.py

Lines changed: 5 additions & 0 deletions
@@ -46,6 +46,11 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel
+
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         compilation_config = vllm_config.compilation_config
         if compilation_config.level == CompilationLevel.NO_COMPILATION:
             # TPU does not support NO_COMPILATION

vllm/platforms/xpu.py

Lines changed: 4 additions & 0 deletions
@@ -51,6 +51,10 @@ def inference_mode():

     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        cache_config = vllm_config.cache_config
+        if cache_config and cache_config.block_size is None:
+            cache_config.block_size = 16
+
         # check and update model config
         model_config = vllm_config.model_config
         if model_config.dtype == torch.bfloat16:
