Skip to content

Commit ceca060

Browse files
[Deprecation] Deprecate seed=None (#29185)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 75648b1 commit ceca060

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

vllm/engine/arg_utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class EngineArgs:
367367
config_format: str = ModelConfig.config_format
368368
dtype: ModelDType = ModelConfig.dtype
369369
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
370-
seed: int | None = None
370+
seed: int | None = 0
371371
max_model_len: int | None = ModelConfig.max_model_len
372372
cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
373373
cudagraph_capture_sizes: list[int] | None = (
@@ -1192,6 +1192,12 @@ def create_model_config(self) -> ModelConfig:
11921192
# VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
11931193
# doesn't affect the user process.
11941194
if self.seed is None:
1195+
logger.warning_once(
1196+
"`seed=None` is equivalent to `seed=0` in V1 Engine. "
1197+
"You will no longer be allowed to pass `None` in v0.13.",
1198+
scope="local",
1199+
)
1200+
11951201
self.seed = 0
11961202
if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
11971203
logger.warning(
@@ -1203,28 +1209,31 @@ def create_model_config(self) -> ModelConfig:
12031209
)
12041210

12051211
if self.disable_mm_preprocessor_cache:
1206-
logger.warning(
1212+
logger.warning_once(
12071213
"`--disable-mm-preprocessor-cache` is deprecated "
12081214
"and will be removed in v0.13. "
12091215
"Please use `--mm-processor-cache-gb 0` instead.",
1216+
scope="local",
12101217
)
12111218

12121219
self.mm_processor_cache_gb = 0
12131220
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
1214-
logger.warning(
1221+
logger.warning_once(
12151222
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
12161223
"and will be removed in v0.13. "
12171224
"Please use `--mm-processor-cache-gb %d` instead.",
12181225
envs.VLLM_MM_INPUT_CACHE_GIB,
1226+
scope="local",
12191227
)
12201228

12211229
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
12221230

12231231
if self.enable_multimodal_encoder_data_parallel:
1224-
logger.warning(
1232+
logger.warning_once(
12251233
"--enable-multimodal-encoder-data-parallel` is deprecated "
12261234
"and will be removed in v0.13. "
1227-
"Please use `--mm-encoder-tp-mode data` instead."
1235+
"Please use `--mm-encoder-tp-mode data` instead.",
1236+
scope="local",
12281237
)
12291238

12301239
self.mm_encoder_tp_mode = "data"

0 commit comments

Comments
 (0)