44 changes: 29 additions & 15 deletions vllm/v1/engine/core.py
@@ -23,9 +23,10 @@
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
+from vllm.multimodal.inputs import MultiModalKwargs
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
-from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
+from vllm.utils import is_list_of, make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@@ -194,7 +195,9 @@ def _initialize_kv_caches(
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

-def add_request(self, request: EngineCoreRequest):
+def add_request(self, request: Union[EngineCoreRequest, Request]):
+    if type(request) is EngineCoreRequest:
+        request = self._preprocess_add_request(request)
"""Add request to the scheduler."""
if pooling_params := request.pooling_params:
supported_pooling_tasks = (
@@ -203,27 +206,30 @@ def add_request(self, request: EngineCoreRequest):
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
f"Supported tasks: {supported_pooling_tasks}")

-if request.mm_hashes is not None:
+if request.mm_hashes:
# Here, if hash exists for a multimodal input, then it will be
# fetched from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client cache, so
# anything that has a hash must have a HIT cache entry here
# as well.
-assert request.mm_inputs is not None
-request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+assert request.mm_inputs
+updated_mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
+assert isinstance(updated_mm_inputs, list)
+assert is_list_of(updated_mm_inputs, MultiModalKwargs), (
+    "Invalid updated mm_inputs in EngineCore.add_request")
+request.mm_inputs = updated_mm_inputs

-req = Request.from_engine_core_request(request)
-if req.use_structured_output:
+if request.use_structured_output:
# Start grammar compilation asynchronously
-self.structured_output_manager.grammar_init(req)
+self.structured_output_manager.grammar_init(request)

-if req.kv_transfer_params is not None and (
+if request.kv_transfer_params is not None and (
not self.scheduler.get_kv_connector()):
logger.warning("Got kv_transfer_params, but no KVConnector found. "
"Disabling KVTransfer for this request.")

-self.scheduler.add_request(req)
+self.scheduler.add_request(request)

def abort_requests(self, request_ids: list[str]):
"""Abort requests from the scheduler."""
@@ -385,6 +391,13 @@ def save_tensorized_model(
self.model_executor.save_tensorized_model(
tensorizer_config=tensorizer_config, )

+def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:
+    """Preprocess the request.
+
+    This function can be used directly in the input processing thread so that
+    request initialization runs in parallel with the model forward pass."""
+    return Request.from_engine_core_request(request)


class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
@@ -766,10 +779,11 @@ def process_input_sockets(self, input_addresses: list[str],
bytes(type_frame.buffer))

# Deserialize the request data.
-decoder = add_request_decoder if (
-    request_type
-    == EngineCoreRequestType.ADD) else generic_decoder
-request = decoder.decode(data_frames)
+if request_type == EngineCoreRequestType.ADD:
+    request = add_request_decoder.decode(data_frames)
+    request = self._preprocess_add_request(request)
+else:
+    request = generic_decoder.decode(data_frames)

# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
@@ -915,7 +929,7 @@ def shutdown(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

-def add_request(self, request: EngineCoreRequest):
+def add_request(self, request: Union[EngineCoreRequest, Request]):
if self.has_coordinator and request.current_wave != self.current_wave:
if request.current_wave > self.current_wave:
self.current_wave = request.current_wave
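The core.py change above moves the EngineCoreRequest-to-Request conversion out of the busy loop and into the input processing thread. A minimal, self-contained sketch of that pattern, using hypothetical RawRequest/PreparedRequest stand-ins rather than the real vLLM classes:

```python
import queue
import threading
from dataclasses import dataclass


@dataclass
class RawRequest:  # stand-in for EngineCoreRequest
    request_id: str


@dataclass
class PreparedRequest:  # stand-in for the scheduler-ready Request
    request_id: str


def preprocess(raw: RawRequest) -> PreparedRequest:
    # Per-request construction work happens here, off the busy loop
    # (analogous in spirit to _preprocess_add_request above).
    return PreparedRequest(request_id=raw.request_id)


input_queue: "queue.Queue[PreparedRequest]" = queue.Queue()


def input_thread(raw_requests: list[RawRequest]) -> None:
    # Analogous to process_input_sockets(): decode, preprocess, enqueue.
    for raw in raw_requests:
        input_queue.put(preprocess(raw))


def busy_loop(num_requests: int) -> None:
    # Analogous to the core busy loop: it only sees prepared requests,
    # so preprocessing overlaps with (here, simulated) model execution.
    for _ in range(num_requests):
        req = input_queue.get()
        print(f"scheduling {req.request_id}")


raws = [RawRequest(f"req-{i}") for i in range(3)]
t = threading.Thread(target=input_thread, args=(raws,))
t.start()
busy_loop(len(raws))
t.join()
```

The PR makes the same split: Request construction runs on the socket thread, so the busy loop only pays for a queue get before handing the request to the scheduler.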
7 changes: 4 additions & 3 deletions vllm/v1/engine/core_client.py
@@ -32,6 +32,7 @@
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager, launch_core_engines)
from vllm.v1.executor.abstract import Executor
+from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr

logger = init_logger(__name__)
@@ -104,7 +105,7 @@ def shutdown(self):
def get_output(self) -> EngineCoreOutputs:
raise NotImplementedError

-def add_request(self, request: EngineCoreRequest) -> None:
+def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
Collaborator:
where else are we calling add_request that we need to keep union of 2 types here?

Contributor Author:
At least 2 places I noticed. Especially for the latter one, I'm not fully sure whether we need to further update the interface upstream.

An alternative approach is to

  • Add a non-Union API: add_request(self, request: Request)
  • Expose the EngineCoreRequest-to-Request conversion as its own API (e.g. preprocess_add_request(EngineCoreRequest) -> Request) and update the logic on the caller side (see the sketch below)

WDYT?
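A rough sketch of what that alternative could look like (method names are illustrative assumptions, not what this PR adds; scheduler wiring omitted):

```python
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.request import Request


class EngineCore:  # shape sketch only, not the real class
    def preprocess_add_request(self, request: EngineCoreRequest) -> Request:
        # Conversion exposed as its own API so callers (e.g. the input
        # processing thread) can run it off the critical path.
        return Request.from_engine_core_request(request)

    def add_request(self, request: Request) -> None:
        # The core then only ever accepts an already-converted Request.
        self.scheduler.add_request(request)
```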

Collaborator:
those should be for the sync engine. we should be able to trigger them directly using pythonic api or bench throughput. do we see similar benefits?

raise NotImplementedError

def profile(self, is_start: bool = True) -> None:
@@ -238,7 +239,7 @@ def get_output(self) -> EngineCoreOutputs:
outputs, _ = self.engine_core.step()
return outputs.get(0) or EngineCoreOutputs()

-def add_request(self, request: EngineCoreRequest) -> None:
+def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
self.engine_core.add_request(request)

def abort_requests(self, request_ids: list[str]) -> None:
@@ -603,7 +604,7 @@ def call_utility(self, method: str, *args) -> Any:

return future.result()

-def add_request(self, request: EngineCoreRequest) -> None:
+def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
if self.is_dp:
self.engines_running = True
self._send_input(EngineCoreRequestType.ADD, request)
3 changes: 3 additions & 0 deletions vllm/v1/request.py
@@ -35,10 +35,12 @@ def __init__(
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
+current_wave: int = 0,
Member:
Instead of adding this to the internal Request, maybe convert EngineCoreRequest to a tuple[Request, int]?

Contributor Author:
Hmmm, I'm leaning toward adding it to Request instead.

  1. current_wave is an existing field in EngineCoreRequest
  2. If we go with tuple[Request, int], I'm afraid we might end up having tuple[Request, A, B, C, D, ...] in the future :/

WDYT?

Member:
current_wave is logically separate from the request, it's only used for coordination purposes at the point that the request is received. Request is the scheduler's state for the request so it doesn't really belong in there. So I don't think what you mentioned will be a concern.

Contributor Author:
How about wrapping current_wave in a new class (e.g. RequestEnv)? And the interface would become tuple[Request, RequestEnv]

In this way,

  • non-request data (i.e. current_wave) goes into RequestEnv
  • when similar new fields come in, we have a place for them without touching the interface

WDYT?

Contributor Author:
Ditto, gentle nudge @njhill for your thoughts :)

Member:
New class would be more overhead, tuples are very cheap (small allocs are reused). I don't think we have to worry about a place for other values, I don't think it's likely that there will be, and it's better to do that if/when needed in future. This isn't an external interface.

Contributor Author:
@njhill Appreciate your input.

I've handed the idea over to @linzebing, and most of your comments should have been addressed. Let's move the discussion there: #21627
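For reference, a minimal sketch of the tuple[Request, int] shape discussed in this thread (illustrative only, not what this PR implements):

```python
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.request import Request


def preprocess_add_request(request: EngineCoreRequest) -> tuple[Request, int]:
    """Convert to the scheduler's Request, keeping current_wave out of it."""
    return Request.from_engine_core_request(request), request.current_wave


# Caller side (e.g. in the DP engine's add_request): use the wave for
# coordination, then hand only the Request to the scheduler.
# req, wave = preprocess_add_request(engine_core_request)
# ...wave handling...
# self.scheduler.add_request(req)
```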

priority: int = 0,
) -> None:
self.request_id = request_id
self.client_index = client_index
+self.current_wave = current_wave
self.priority = priority
self.sampling_params = sampling_params
self.pooling_params = pooling_params
@@ -131,6 +133,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
+current_wave=request.current_wave,
priority=request.priority,
)
