Commit 391612e

[Frontend] Consolidate tokenizer init code (#26276)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 77c95f7 · commit 391612e

8 files changed: +46 -70 lines changed

tests/test_inputs.py

Lines changed: 3 additions & 5 deletions
@@ -7,7 +7,6 @@
 from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_raw_prompts
 from vllm.inputs.preprocess import InputPreprocessor
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

 pytestmark = pytest.mark.cpu_test

@@ -107,8 +106,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
 )
 def test_preprocessor_text_no_mm_inputs(model_id, prompt):
     model_config = ModelConfig(model=model_id)
-    tokenizer = init_tokenizer_from_configs(model_config)
-    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    input_preprocessor = InputPreprocessor(model_config)

     with pytest.raises(ValueError, match="does not support multimodal inputs"):
         input_preprocessor.preprocess(prompt)
@@ -129,8 +127,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
 )
 def test_preprocessor_always_mm_code_path(model_id, prompt):
     model_config = ModelConfig(model=model_id)
-    tokenizer = init_tokenizer_from_configs(model_config)
-    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    input_preprocessor = InputPreprocessor(model_config)
+    tokenizer = input_preprocessor.tokenizer

     # HF processor adds sep token
     sep_token_id = tokenizer.vocab[tokenizer.sep_token]
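As the updated tests show, callers now construct InputPreprocessor from the model config alone and read the tokenizer off the instance. A minimal sketch of the new pattern (the model id and the ModelConfig import path are assumptions for illustration):

from vllm.config import ModelConfig  # assumed import path
from vllm.inputs.preprocess import InputPreprocessor

model_config = ModelConfig(model="facebook/opt-125m")  # illustrative model id
input_preprocessor = InputPreprocessor(model_config)   # tokenizer is built internally now
tokenizer = input_preprocessor.tokenizer               # None when skip_tokenizer_init is set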

tests/v1/engine/test_processor_multi_modal_uuids.py

Lines changed: 1 addition & 3 deletions
@@ -65,9 +65,7 @@ def __init__(self, gb: float):
         device_config=DeviceConfig(device="cpu"),
     )

-    # Pass tokenizer=None; InputPreprocessor handles None when
-    # skip_tokenizer_init is True.
-    return Processor(vllm_config, tokenizer=None)  # type: ignore[arg-type]
+    return Processor(vllm_config)


 def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):

vllm/entrypoints/llm.py

Lines changed: 2 additions & 6 deletions
@@ -74,7 +74,6 @@
     AnyTokenizer,
     MistralTokenizer,
     get_cached_tokenizer,
-    init_tokenizer_from_configs,
 )
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, Device, as_iter, is_list_of
@@ -367,11 +366,8 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
     def _get_processor(self) -> Processor:
         if not hasattr(self, "_processor"):
             vllm_config = self.llm_engine.vllm_config
-            if self.model_config.skip_tokenizer_init:
-                tokenizer = None
-            else:
-                tokenizer = init_tokenizer_from_configs(self.model_config)
-            self._processor = Processor(vllm_config, tokenizer)
+            self._processor = Processor(vllm_config)
+
         return self._processor

     def get_default_sampling_params(self) -> SamplingParams:

vllm/entrypoints/openai/serving_engine.py

Lines changed: 2 additions & 6 deletions
@@ -16,7 +16,6 @@
 from typing_extensions import TypeIs

 from vllm.entrypoints.utils import _validate_truncation_size
-from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.processor import Processor

@@ -272,11 +271,8 @@ def __init__(
     async def _get_processor(self) -> Processor:
         if not hasattr(self, "_processor"):
             vllm_config = await self.engine_client.get_vllm_config()
-            if self.model_config.skip_tokenizer_init:
-                tokenizer = None
-            else:
-                tokenizer = init_tokenizer_from_configs(self.model_config)
-            self._processor = Processor(vllm_config, tokenizer)
+            self._processor = Processor(vllm_config)
+
         return self._processor

     def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:

vllm/inputs/preprocess.py

Lines changed: 11 additions & 16 deletions
@@ -17,7 +17,8 @@
     MultiModalUUIDDict,
 )
 from vllm.multimodal.processing import BaseMultiModalProcessor
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.utils.jsontree import json_iter_leaves

 from .data import (
     DecoderOnlyInputs,
@@ -44,17 +45,20 @@ class InputPreprocessor:
     def __init__(
         self,
         model_config: ModelConfig,
-        tokenizer: Optional[AnyTokenizer],
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
     ) -> None:
         super().__init__()

         self.model_config = model_config
-        self.tokenizer = tokenizer
         self.mm_registry = mm_registry
         self.mm_processor_cache = mm_processor_cache

+        if model_config.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = init_tokenizer_from_configs(model_config)
+
     def get_tokenizer(self) -> AnyTokenizer:
         if self.tokenizer is None:
             raise ValueError(
@@ -273,7 +277,10 @@ def _process_multimodal(
         mm_hashes = mm_input["mm_hashes"]

         # Validate that all mm items have a string as their hash
-        if not contains_only_strings(mm_hashes):
+        contains_only_strings = all(
+            isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
+        )
+        if not contains_only_strings:
             raise ValueError(
                 f"mm_hashes must contain only strings, got: {mm_hashes}. "
                 "This is likely due to an incorrect custom implementation of "
@@ -693,15 +700,3 @@ def preprocess(
     def clear_cache(self) -> None:
         if self.mm_processor_cache is not None:
             self.mm_processor_cache.clear_cache()
-
-
-# Helper function to validate that a nested dictionary contains
-# only strings or list of strings as the leaf values.
-def contains_only_strings(obj: object):
-    if isinstance(obj, str):
-        return True
-    if isinstance(obj, list):
-        return all(isinstance(x, str) for x in obj)
-    if isinstance(obj, dict):
-        return all(contains_only_strings(v) for v in obj.values())
-    return False
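Two things consolidate here: tokenizer construction moves into InputPreprocessor.__init__ (honoring skip_tokenizer_init), and the ad-hoc contains_only_strings helper is replaced by a leaf walk over the mm_hashes structure. A rough illustration of the equivalent check, assuming json_iter_leaves yields the scalar leaves of nested dicts and lists as its use above implies (the mm_hashes value below is made up):

from vllm.utils.jsontree import json_iter_leaves

mm_hashes = {"image": ["a1b2c3", "d4e5f6"], "audio": []}  # hypothetical hashes
ok = all(isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes))
assert ok, f"mm_hashes must contain only strings, got: {mm_hashes}"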

vllm/v1/engine/async_llm.py

Lines changed: 6 additions & 14 deletions
@@ -28,7 +28,7 @@
 from vllm.tasks import SupportedTask
 from vllm.tracing import init_tracer
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
-from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
 from vllm.v1.engine import EngineCoreRequest
@@ -104,20 +104,8 @@ def __init__(
                 "logger list; enabling logging without default stat loggers"
             )

-        if self.model_config.skip_tokenizer_init:
-            self.tokenizer = None
-        else:
-            # Tokenizer (+ ensure liveness if running in another process).
-            self.tokenizer = init_tokenizer_from_configs(
-                model_config=vllm_config.model_config
-            )
-
         # Processor (converts Inputs --> EngineCoreRequests).
-        self.processor = Processor(
-            vllm_config=vllm_config,
-            tokenizer=self.tokenizer,
-            mm_registry=mm_registry,
-        )
+        self.processor = Processor(vllm_config, mm_registry=mm_registry)

         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(
@@ -257,6 +245,10 @@ def shutdown(self):

         cancel_task_threadsafe(getattr(self, "output_handler", None))

+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.processor.tokenizer
+
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return await self.engine_core.get_supported_tasks_async()
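AsyncLLM no longer stores its own tokenizer; the new read-only property forwards to the processor, so existing call sites that only read engine.tokenizer keep working. A hedged sketch (engine construction elided; encode usage assumes an HF-style tokenizer):

# engine: AsyncLLM, constructed elsewhere
tok = engine.tokenizer               # same object as engine.processor.tokenizer
if tok is not None:                  # None when skip_tokenizer_init was requested
    ids = tok.encode("hello world")  # ordinary tokenizer usage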

vllm/v1/engine/llm_engine.py

Lines changed: 10 additions & 12 deletions
@@ -23,7 +23,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.tracing import init_tracer
-from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
 from vllm.v1.engine import EngineCoreRequest
@@ -95,18 +95,8 @@ def __init__(
         self.dp_group = None
         self.should_execute_dummy_batch = False

-        if self.model_config.skip_tokenizer_init:
-            self.tokenizer = None
-        else:
-            # Tokenizer (+ ensure liveness if running in another process).
-            self.tokenizer = init_tokenizer_from_configs(
-                model_config=vllm_config.model_config
-            )
-
         # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(
-            vllm_config=vllm_config, tokenizer=self.tokenizer, mm_registry=mm_registry
-        )
+        self.processor = Processor(vllm_config, mm_registry=mm_registry)

         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(
@@ -214,6 +204,14 @@ def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
     def validate_outputs(cls, outputs, output_type):
         return outputs

+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.processor.tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
+        self.processor.tokenizer = tokenizer
+
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()

vllm/v1/engine/processor.py

Lines changed: 11 additions & 8 deletions
@@ -37,15 +37,13 @@ class Processor:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        tokenizer: AnyTokenizer,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-    ):
+    ) -> None:
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.structured_outputs_config = vllm_config.structured_outputs_config
-        self.tokenizer = tokenizer

         self.generation_config_fields = self.model_config.try_get_generation_config()

@@ -54,11 +52,18 @@ def __init__(

         self.input_preprocessor = InputPreprocessor(
             self.model_config,
-            self.tokenizer,
             mm_registry,
             mm_processor_cache=self.mm_processor_cache,
         )

+    @property
+    def tokenizer(self) -> Optional[AnyTokenizer]:
+        return self.input_preprocessor.tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
+        self.input_preprocessor.tokenizer = tokenizer
+
     def _validate_logprobs(
         self,
         params: SamplingParams,
@@ -511,10 +516,8 @@ def _validate_model_input(
         else:
             raise ValueError(f"The {prompt_type} prompt cannot be empty")

-        if self.model_config.skip_tokenizer_init:
-            tokenizer = None
-        else:
-            tokenizer = self.tokenizer
+        tokenizer = self.tokenizer
+        if tokenizer is not None:
             max_input_id = max(prompt_ids or [], default=0)

             # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
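With this change the tokenizer has a single owner, InputPreprocessor, and the layers above delegate to it through properties (LLMEngine.tokenizer -> Processor.tokenizer -> InputPreprocessor.tokenizer); _validate_model_input likewise just checks whether the tokenizer is None instead of consulting skip_tokenizer_init again. A small sketch of the resulting behavior, with vllm_config and my_tokenizer as placeholders:

processor = Processor(vllm_config)          # no tokenizer argument anymore
assert processor.tokenizer is processor.input_preprocessor.tokenizer

processor.tokenizer = my_tokenizer          # setter writes through to the preprocessor
assert processor.input_preprocessor.tokenizer is my_tokenizer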
