tests/v1/engine/test_engine_core.py (26 additions & 18 deletions)
@@ -65,7 +65,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     """Test basic request lifecycle."""
 
     # First request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
 
@@ -74,7 +75,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 1
 
     # Second request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1
 
@@ -83,8 +85,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 2
 
     # Add two requests in a row.
-    engine_core.add_request(make_request())
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2
 
@@ -104,7 +108,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req = make_request()
     request_id = req.request_id
 
-    engine_core.add_request(req)
+    engine_core.add_request(*engine_core.preprocess_add_request(req))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     assert engine_core.scheduler.has_unfinished_requests()
@@ -131,16 +135,16 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req1 = make_request()
     req2 = make_request()
 
-    engine_core.add_request(req0)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0
 
     _ = engine_core.step()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
 
-    engine_core.add_request(req2)
+    engine_core.add_request(*engine_core.preprocess_add_request(req2))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2
 
@@ -166,12 +170,12 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req0 = make_request()
     req1 = make_request()
     req0.request_id = req1.request_id = "test"
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
 
     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass
 
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass
 
@@ -207,7 +211,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         repetition_penalty=0.1,
         stop_token_ids=[1001, 1002],
     )
-    engine_core.add_request(request)
+    engine_core.add_request(*engine_core.preprocess_add_request(request))
 
     def _check_engine_state():
         assert len(engine_core.scheduler.waiting) == 1
@@ -226,7 +230,7 @@ def _check_engine_state():
         top_p=0.99,
         top_k=50,
     )
-    engine_core.add_request(request2)
+    engine_core.add_request(*engine_core.preprocess_add_request(request2))
     _check_engine_state()
 
 
@@ -298,9 +302,9 @@ def shutdown(self):
 
     # Add two requests in a row. Each request have 12 prompt tokens.
     req0 = make_request_with_max_tokens("0", 5)
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
     req1 = make_request_with_max_tokens("1", 5)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
 
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
@@ -436,26 +440,30 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
 
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*UUID"):
-        engine_core.add_request(uuid_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(uuid_request))
 
     # Test with integer
     int_request = make_request()
     int_request.request_id = 12345
 
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*int"):
-        engine_core.add_request(int_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(int_request))
 
     # Test with None
     none_request = make_request()
     none_request.request_id = None
 
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*NoneType"):
-        engine_core.add_request(none_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(none_request))
 
     # Verify engine is still functional after errors
     valid_request = make_request()
-    engine_core.add_request(valid_request)
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(valid_request))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
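Every call site in the tests above follows the same pattern: `preprocess_add_request()` returns a `(Request, request_wave)` tuple, and the `*` unpacks it into the two positional arguments of `add_request()`. For clarity, a minimal equivalent of the unpacked form, reusing the `engine_core` and `make_request` fixtures from these tests:

# Equivalent two-step form of the starred calls used throughout the tests.
req, wave = engine_core.preprocess_add_request(make_request())
engine_core.add_request(req, wave)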
vllm/v1/engine/core.py (45 additions & 29 deletions)
@@ -205,8 +205,12 @@ def _initialize_kv_caches(
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks
 
-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicate which wave of requests this is expected to
+        belong to in DP case
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(
@@ -222,27 +226,12 @@ def add_request(self, request: EngineCoreRequest):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
 
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")
 
-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)
 
     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -414,6 +403,31 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function could be directly used in input processing thread to allow
+        request initialization running in parallel with Model forward
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and will only accessed in the input processing thread afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked in input processing thread. For
+            # `structured_output_manager`, each request is independent and
+            # grammar compilation is async. Scheduler always checks grammar
+            # compilation status before scheduling request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -707,7 +721,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
         """Dispatch request from client."""
 
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
@@ -806,10 +821,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))
 
                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self.preprocess_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)
 
                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
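Together with the `_handle_client_request` hunk above, the payload crossing `input_queue` now depends on the request type: ADD items arrive already preprocessed as a `(Request, request_wave)` tuple, while other request types carry the decoded object directly. A small illustrative consumer; the helper name and the `vllm.v1.engine` import location are assumptions, not code from this PR:

from typing import Any

from vllm.v1.engine import EngineCoreRequestType  # assumed import location

def handle_queue_item(engine_core, request_type: EngineCoreRequestType,
                      payload: Any) -> None:
    # Mirrors the dispatch in _handle_client_request(): ADD payloads are
    # already a (Request, request_wave) tuple built on the socket thread.
    if request_type == EngineCoreRequestType.ADD:
        req, request_wave = payload
        engine_core.add_request(req, request_wave)
    elif request_type == EngineCoreRequestType.ABORT:
        engine_core.abort_requests(payload)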
@@ -939,17 +955,17 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))
 
-        super().add_request(request)
+        super().add_request(request, request_wave)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
vllm/v1/engine/core_client.py (2 additions & 1 deletion)
@@ -250,7 +250,8 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()
 
     def add_request(self, request: EngineCoreRequest) -> None:
-        self.engine_core.add_request(request)
+        req, request_wave = self.engine_core.preprocess_add_request(request)
+        self.engine_core.add_request(req, request_wave)
 
     def abort_requests(self, request_ids: list[str]) -> None:
         if len(request_ids) > 0: