Skip to content

Commit 71ebc50

Browse files
committed
Put MM initialization back to EngineCore.add_request to avoid race condition
Signed-off-by: Jialin Ouyang <[email protected]>
1 parent bccab82 commit 71ebc50

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

vllm/v1/engine/core.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
from vllm.logger import init_logger
2424
from vllm.logging_utils.dump_input import dump_engine_exception
2525
from vllm.lora.request import LoRARequest
26+
from vllm.multimodal.inputs import MultiModalKwargs
2627
from vllm.transformers_utils.config import (
2728
maybe_register_config_serialize_by_value)
28-
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
29+
from vllm.utils import is_list_of, make_zmq_socket, resolve_obj_by_qualname
2930
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
3031
unify_kv_cache_configs)
3132
from vllm.v1.core.sched.interface import SchedulerInterface
@@ -203,6 +204,20 @@ def add_request(self, request: Request):
203204
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
204205
f"Supported tasks: {supported_pooling_tasks}")
205206

207+
if request.mm_hashes is not None:
208+
# Here, if hash exists for a multimodal input, then it will be
209+
# fetched from the cache, else it will be added to the cache.
210+
# Note that the cache here is mirrored with the client cache, so
211+
# anything that has a hash must have a HIT cache entry here
212+
# as well.
213+
assert request.mm_inputs is not None
214+
updated_mm_inputs = self.mm_input_cache_server.get_and_update_p1(
215+
request.mm_inputs, request.mm_hashes)
216+
assert isinstance(updated_mm_inputs, list)
217+
assert is_list_of(updated_mm_inputs, MultiModalKwargs), (
218+
"mm_inputs was not updated in EngineCore.add_request")
219+
request.mm_inputs = updated_mm_inputs
220+
206221
if request.use_structured_output:
207222
# Start grammar compilation asynchronously
208223
self.structured_output_manager.grammar_init(request)
@@ -830,16 +845,6 @@ def _post_process_add_request(self, request: EngineCoreRequest) -> Request:
830845
831846
This call is executed in parallel with the model forward pass, which
832847
moves request-preparation work off the critical path."""
833-
if request.mm_hashes is not None:
834-
# Here, if hash exists for a multimodal input, then it will be
835-
# fetched from the cache, else it will be added to the cache.
836-
# Note that the cache here is mirrored with the client cache, so
837-
# anything that has a hash must have a HIT cache entry here
838-
# as well.
839-
assert request.mm_inputs is not None
840-
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
841-
request.mm_inputs, request.mm_hashes)
842-
843848
return Request.from_engine_core_request(request)
844849

845850

0 commit comments

Comments
 (0)