Skip to content

Commit 6789bb6

Browse files
ahenglj and claude committed
[Review] Address PR #1303 feedback: reduce scope, fix review comments
- Remove non-PD files: gpu_ar_model_runner.py (debug logging only), omni_ar_scheduler.py and omni_generation_scheduler.py (general compat shims, not PD-specific), pd_server_patch_guide.md (superseded by monkey_patch.py)
- Downgrade all KV-DIAG logging from WARNING to DEBUG (omni_llm.py, omni_stage.py)
- Strip verbose per-step/per-batch diagnostic scaffolding from omni_llm.py and omni_stage.py
- patched_mooncake_connector: call super().add_new_req() instead of skipping; use copy-and-restore pattern in group_kv_pull
- omni.py: refactor _detect_pd_separation to single-pass; deduplicate _kv_cfg_to_dict/_normalize_kv_transfer_params into _to_dict()
- async_omni.py: unify PD routing merge semantics with sync path
- qwen3_omni stage_input_processors: replace hardcoded "0"/"24" layer keys with named constants
- qwen3_omni model: document zero-padding safety for PD disaggregation
- omni_llm: add comment explaining why _flush_kv_connector_sends reaches into vLLM internals

PR scope reduced from 15 to 11 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 606e7cf commit 6789bb6

File tree

11 files changed

+150
-566
lines changed

11 files changed

+150
-566
lines changed

docs/design/pd_server_patch_guide.md

Lines changed: 0 additions & 185 deletions
This file was deleted.

vllm_omni/core/sched/omni_ar_scheduler.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,29 +68,6 @@ def __init__(self, *args, **kwargs):
6868
if getattr(model_config, "async_chunk", False):
6969
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
7070

71-
def _get_routed_experts(self, request: Request):
72-
"""Return routed-experts array for *request*, or ``None``.
73-
74-
Delegates to the parent ``Scheduler`` when it provides this method
75-
(vLLM >= 0.9); otherwise returns ``None`` so that older vLLM
76-
installations don't crash.
77-
"""
78-
parent = getattr(super(), "_get_routed_experts", None)
79-
if parent is not None:
80-
return parent(request)
81-
return None
82-
83-
def _handle_stopped_request(self, request: Request) -> bool:
84-
"""Handle a stopped request — returns ``True`` when truly finished.
85-
86-
Delegates to the parent ``Scheduler`` when it provides this method
87-
(vLLM >= 0.9); otherwise falls back to checking the request status.
88-
"""
89-
parent = getattr(super(), "_handle_stopped_request", None)
90-
if parent is not None:
91-
return parent(request)
92-
return request.status.is_finished
93-
9471
def _get_kv_transfer_criteria(self) -> dict | None:
9572
# Note: vllm_config is available in Scheduler after super().__init__
9673
if not hasattr(self, "vllm_config"):

vllm_omni/core/sched/omni_generation_scheduler.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,6 @@ def __init__(self, *args, **kwargs):
3030
if getattr(model_config, "async_chunk", False):
3131
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
3232

33-
def _get_routed_experts(self, request: Request):
34-
"""Return routed-experts array for *request*, or ``None``.
35-
36-
Delegates to the parent ``Scheduler`` when it provides this method
37-
(vLLM >= 0.9); otherwise returns ``None`` so that older vLLM
38-
installations don't crash.
39-
"""
40-
parent = getattr(super(), "_get_routed_experts", None)
41-
if parent is not None:
42-
return parent(request)
43-
return None
44-
45-
def _handle_stopped_request(self, request: Request) -> bool:
46-
"""Handle a stopped request — returns ``True`` when truly finished.
47-
48-
Delegates to the parent ``Scheduler`` when it provides this method
49-
(vLLM >= 0.9); otherwise falls back to checking the request status.
50-
"""
51-
parent = getattr(super(), "_handle_stopped_request", None)
52-
if parent is not None:
53-
return parent(request)
54-
return request.status.is_finished
55-
5633
def schedule(self) -> SchedulerOutput:
5734
"""Diffusion fast path:
5835
- Feed all input tokens of the request at once

vllm_omni/distributed/kv_transfer/patched_mooncake_connector.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,22 @@ def add_new_req(
149149
kv_transfer_params: dict[str, Any] | None = None,
150150
**kwargs: Any,
151151
) -> None:
152-
"""Override to store a ``PatchedRecvReqMeta`` that remembers the
153-
prefill engine's ``remote_request_id``.
152+
"""Call ``super().add_new_req()`` for all requests, then layer
153+
PD-specific ``PatchedRecvReqMeta`` on top for decode-side
154+
(``load_remote_cache=True``) requests.
154155
155-
When ``kv_transfer_params`` contains ``"remote_request_id"``, we
156-
use it for the ZMQ look-up key. Otherwise we fall back to the
157-
local ``request_id`` (original behaviour).
156+
This ensures any future logic added to the base method is
157+
always executed, while still providing the
158+
``remote_request_id`` mapping needed for PD disaggregation.
158159
"""
160+
# Always call super() so base-class bookkeeping is preserved.
161+
super().add_new_req(
162+
request_id,
163+
local_block_ids,
164+
kv_transfer_params,
165+
**kwargs,
166+
)
167+
159168
kv_transfer_params = kv_transfer_params or {}
160169
load_remote_cache = kv_transfer_params.get(
161170
"do_remote_prefill",
@@ -172,7 +181,8 @@ def add_new_req(
172181
local_block_ids=local_block_ids,
173182
kv_transfer_params=kv_transfer_params,
174183
)
175-
# Store in the same structure the base class uses
184+
# Override the entry created by super() with our patched
185+
# version that carries remote_request_id.
176186
if not hasattr(self, "_reqs_need_recv"):
177187
self._reqs_need_recv = {}
178188
self._reqs_need_recv[request_id] = meta
@@ -183,27 +193,25 @@ def add_new_req(
183193
remote_request_id,
184194
self.engine_id,
185195
)
186-
else:
187-
# Producer side — delegate to original
188-
super().add_new_req(
189-
request_id,
190-
local_block_ids,
191-
kv_transfer_params,
192-
**kwargs,
193-
)
194196

195197
def group_kv_pull(self, metadata: Any | None = None) -> None:
196198
"""Override to use ``meta.remote_request_id`` as the ZMQ look-up
197199
key instead of the local request ID.
198200
199-
After issuing the pull, we record the remote→local mapping in
200-
``self.remote_to_local_req`` so ``receive_kv`` can translate
201-
back.
201+
We build a patched copy of ``_reqs_need_recv`` with
202+
``remote_request_id`` as the key so the base class ZMQ logic
203+
looks up the correct remote KV cache. The original dict is
204+
restored after ``super().group_kv_pull()`` returns to avoid
205+
confusing the base class with unexpected mutations.
202206
"""
203207
if not hasattr(self, "_reqs_need_recv") or not self._reqs_need_recv:
204208
return
205209

206-
for local_id, meta in list(self._reqs_need_recv.items()):
210+
# Build a patched copy; keep the original for restoration.
211+
original_recv = self._reqs_need_recv.copy()
212+
patched_recv: dict[str, Any] = {}
213+
214+
for local_id, meta in original_recv.items():
207215
if isinstance(meta, PatchedRecvReqMeta):
208216
remote_id = meta.remote_request_id
209217
self.remote_to_local_req[remote_id] = local_id
@@ -213,19 +221,30 @@ def group_kv_pull(self, metadata: Any | None = None) -> None:
213221
remote_id,
214222
local_id,
215223
)
216-
# Replace with a fake meta that uses remote_id as request_id
217-
# so the base class ZMQ logic uses remote_id to look up KV
224+
# Use remote_id as key so the base class ZMQ logic
225+
# looks up KV under the prefill engine's request ID.
218226
patched_meta = type(meta)(
219227
request_id=remote_id,
220228
remote_request_id=remote_id,
221229
local_block_ids=meta.local_block_ids,
222230
kv_transfer_params=meta.kv_transfer_params,
223231
)
224-
self._reqs_need_recv[local_id] = patched_meta
232+
patched_recv[remote_id] = patched_meta
233+
else:
234+
patched_recv[local_id] = meta
225235

226-
# Delegate the actual ZMQ transfer to the base class
236+
# Swap in the patched dict, delegate to the base class, then
237+
# restore entries that weren't consumed.
238+
self._reqs_need_recv = patched_recv
227239
super().group_kv_pull(metadata)
228240

241+
# Restore any entries that the base class didn't consume
242+
# (e.g. still pending transfer) back to their original keys.
243+
for remote_id, local_id in list(self.remote_to_local_req.items()):
244+
if remote_id in self._reqs_need_recv:
245+
entry = self._reqs_need_recv.pop(remote_id)
246+
self._reqs_need_recv[local_id] = original_recv.get(local_id, entry)
247+
229248
def receive_kv(self, path: Any = None, req_blocks: Any = None) -> Any:
230249
"""After the base class completes the ZMQ transfer, map
231250
``remote_id`` back to ``local_id`` in any result structures.

vllm_omni/entrypoints/async_omni.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,18 +518,29 @@ async def _process_sequential_results(
518518
"transfer_id": f"xfer-{request_id}",
519519
}
520520

521+
# Merge any user-provided decode-side kv_transfer_params
522+
# first (same semantics as the sync path in omni.py).
523+
existing_kv_params = self._normalize_kv_transfer_params(
524+
sp_next.extra_args.get("kv_transfer_params")
525+
)
526+
if existing_kv_params:
527+
decode_kv_params.update(existing_kv_params)
528+
529+
# Add prefill engine connection info from config
530+
# (only fill in keys that aren't already present).
521531
if self._pd_connector_info:
522532
eid = self._pd_connector_info.get("prefill_engine_id")
523-
if eid is not None:
533+
if eid is not None and "remote_engine_id" not in decode_kv_params:
524534
decode_kv_params["remote_engine_id"] = eid
525535
baddr = self._pd_connector_info.get("prefill_bootstrap_addr")
526-
if baddr is not None:
536+
if baddr is not None and "remote_bootstrap_addr" not in decode_kv_params:
527537
decode_kv_params["remote_bootstrap_addr"] = baddr
528538

529539
kv_from_prefill = self._extract_kv_transfer_params(engine_outputs)
530540
if kv_from_prefill:
531541
decode_kv_params.update(kv_from_prefill)
532542

543+
# Ensure the decode role flags are correct after merges
533544
decode_kv_params["do_remote_prefill"] = True
534545
decode_kv_params["do_remote_decode"] = False
535546

0 commit comments

Comments (0)