[Feature][WIP] Support Prefill-Decode disaggregation via vLLM KV transfer#1303

Open
ahengljh wants to merge 3 commits into vllm-project:main from ahengljh:feat/pd-disaggregation

Conversation

@ahengljh

Summary

  • Implements Prefill-Decode (PD) disaggregation for the thinker stage in vLLM-Omni, reusing vLLM's native KV connector infrastructure (MooncakeConnector)
  • Splits the thinker into separate prefill (KV producer) and decode (KV consumer) GPU instances, connected via RDMA/TCP KV cache transfer
  • Aligned with vLLM's disaggregated serving proxy pattern (max_tokens=1 for prefill, per-request kv_transfer_params)

Changes

vllm_omni/entrypoints/omni.py

  • Detect is_prefill_only / is_decode_only stage pairs automatically
  • Override max_tokens=1 for prefill stage so vLLM's KV connector saves the KV cache on FINISHED_LENGTH_CAPPED
  • Construct per-request kv_transfer_params for both prefill (do_remote_decode=True) and decode (do_remote_prefill=True, remote_engine_id, remote_bootstrap_addr)
  • Extract prefill engine connection info from YAML config (_get_pd_connector_info())
  • Auto-duplicate thinker sampling params for the decode stage so callers don't need to know about the internal PD split
  • Remove KVMetadataStore — no longer needed since metadata is constructed from config
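The routing described above boils down to two small parameter dicts per request. A minimal sketch, assuming the field names from the bullets above; the helper name `build_pd_kv_params` and the concrete values are hypothetical, not the PR's actual code:

```python
# Hypothetical sketch of the per-request kv_transfer_params the
# orchestrator attaches to each PD leg. Field names follow the PR
# description; values are illustrative only.
def build_pd_kv_params(request_id: str, prefill_engine_id: str,
                       bootstrap_addr: str) -> tuple[dict, dict]:
    prefill_params = {
        "do_remote_decode": True,   # prefill leg: save KV for a remote decoder
        "do_remote_prefill": False,
    }
    decode_params = {
        "do_remote_decode": False,
        "do_remote_prefill": True,  # decode leg: pull KV from the prefill engine
        "remote_engine_id": prefill_engine_id,
        "remote_bootstrap_addr": bootstrap_addr,
    }
    return prefill_params, decode_params

p, d = build_pd_kv_params("req-0", "stage-0", "127.0.0.1:25201")
```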

qwen3_omni_moe_pd_separation.yaml

  • Fix heterogeneous TP (MooncakeConnector requires matching tensor_parallel_size on both sides)
  • Add explicit engine_id for both PD stages so the orchestrator can reference the prefill engine by name
  • Add kv_connector_extra_config with mooncake_bootstrap_port for each stage
  • Clean GPU layout: GPU 0 = prefill, GPU 1 = decode, GPU 2 = talker + code2wav
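A rough sketch of what the two thinker stage entries could look like after these changes. The keys `engine_id`, `kv_connector_extra_config`, and `mooncake_bootstrap_port` come from the bullets above; everything else (field layout, `kv_role` values) is an assumption based on vLLM's KVTransferConfig conventions, not the actual YAML from this PR:

```yaml
# Illustrative fragment only; the real stage config carries many more fields.
- stage_id: 0              # thinker prefill (KV producer), GPU 0
  engine_id: thinker-prefill
  devices: "0"
  kv_transfer_config:
    kv_connector: MooncakeConnector
    kv_role: kv_producer
    kv_connector_extra_config:
      mooncake_bootstrap_port: 25201
- stage_id: 1              # thinker decode (KV consumer), GPU 1
  engine_id: thinker-decode
  devices: "1"
  kv_transfer_config:
    kv_connector: MooncakeConnector
    kv_role: kv_consumer
    kv_connector_extra_config:
      mooncake_bootstrap_port: 25202
```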

How to Test

Prerequisites

  1. Hardware: 3x H100-80G GPUs (or equivalent with enough VRAM for Qwen3-Omni-MoE at TP=1)
  2. Mooncake transfer engine: install via pip install mooncake-transfer-engine or build from source
  3. vllm-omni installed with vLLM as a dependency

Run (offline inference)

```shell
python examples/offline_inference/qwen3_omni/end2end.py \
    --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml \
    --query-type text \
    --output-wav output_pd_test
```

The existing end2end.py works unchanged — the orchestrator auto-duplicates thinker sampling params for the PD decode stage.

GPU Layout (default YAML, TP=1, 3 GPUs)

| GPU | Stage | Role |
| --- | --- | --- |
| 0 | Stage 0 | Thinker Prefill (KV producer) |
| 1 | Stage 1 | Thinker Decode (KV consumer) |
| 2 | Stage 2 + 3 | Talker + Code2Wav |

If the model requires TP=2

Edit the YAML — both PD stages must use the same tensor_parallel_size:

```yaml
# Stage 0 (prefill)
devices: "0,1"
tensor_parallel_size: 2

# Stage 1 (decode)
devices: "2,3"
tensor_parallel_size: 2

# Stage 2+3 (talker, code2wav)
devices: "4"
```

This requires 5 GPUs total.

Success indicators in logs

  • `PD disaggregation detected: prefill=stage-0, decode=stage-1`
  • `PD mode: auto-duplicated thinker sampling params for decode stage 1`
  • `PD routing: injected decode kv_transfer_params for req ...`
  • MooncakeConnector logs showing KV transfer activity
  • Output text/audio files in the output directory

Common issues

| Symptom | Cause | Fix |
| --- | --- | --- |
| `Mooncake transfer engine not found` | mooncake not installed | `pip install mooncake-transfer-engine` |
| CUDA out of memory on stage 0 | Model too large for TP=1 | Set both PD stages to `tensor_parallel_size: 2` with 5 GPUs |
| `NotImplementedError: Heterogeneous TP` | Prefill and decode have different TP sizes | Both must use the same `tensor_parallel_size` |
| Connection refused on bootstrap port | Port conflict or firewall | Check ports 25201/25202 are available |

Test plan

  • Verify PD disaggregation detection logs appear on startup
  • Verify prefill stage produces KV cache (MooncakeConnector send logs)
  • Verify decode stage loads KV cache (MooncakeConnector receive logs)
  • Verify end-to-end text generation matches non-PD output quality
  • Verify end-to-end audio generation pipeline (prefill -> decode -> talker -> code2wav)
  • Test with --query-type use_audio (multimodal input)


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4fb129bceb


Comment on lines +701 to +704
```python
sp.max_tokens = 1
if hasattr(sp, "min_tokens"):
    try:
        sp.min_tokens = 1
```


P1 Badge Force prefill leg to terminate by length in PD mode

This prefill setup only sets max_tokens/min_tokens to 1 but does not neutralize stop conditions, so prompts that hit EOS or stop_token_ids on the first token can finish with finish_reason='stop' instead of length; your own stage-worker note says non-length finishes are treated as not-processed and KV blocks are not sent, which leaves decode reading missing/invalid KV state and producing bad output. Ensure PD prefill requests are configured to always end as length-capped before handoff.
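One way to guarantee a length-capped finish is to neutralize every stop condition alongside the length cap. A sketch of such a helper; the `SamplingParams` field names (`stop`, `stop_token_ids`, `ignore_eos`) match vLLM's, but `force_length_capped` and the `_SP` stand-in are illustrative, not the PR's actual fix:

```python
def force_length_capped(sp) -> None:
    """Mutate a sampling-params object so the PD prefill leg can only
    finish by hitting the length cap, never by a stop condition."""
    sp.max_tokens = 1
    if hasattr(sp, "min_tokens"):
        sp.min_tokens = 1
    # Neutralize stop conditions so EOS / stop strings cannot end the
    # request with finish_reason='stop' before the KV cache is saved.
    sp.stop = None
    sp.stop_token_ids = None
    if hasattr(sp, "ignore_eos"):
        sp.ignore_eos = True

class _SP:  # stand-in for vllm.SamplingParams in this sketch
    max_tokens = 16
    min_tokens = 0
    stop = ["</s>"]
    stop_token_ids = [2]
    ignore_eos = False

sp = _SP()
force_length_capped(sp)
```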


Comment on lines +508 to +512
```python
decode_kv_params: dict[str, Any] = {
    "do_remote_decode": False,
    "do_remote_prefill": True,
    "transfer_id": f"xfer-{request_id}",
}
```


P1 Badge Merge existing decode KV params in async PD routing

In the async PD branch, decode kv_transfer_params are rebuilt from scratch and then written back, which drops any caller-provided decode-side KV fields; unlike the sync path, this also has no post-merge transfer-id fallback when upstream metadata is partial. In async serving, requests that depend on extra connector parameters (or preserved transfer IDs) can fail remote KV lookup and decode incorrectly.


@ahengljh ahengljh changed the title [Feature][WIP] Support Prefill-Decode disaggregation via vLLM KV transfer #1 [Feature][WIP] Support Prefill-Decode disaggregation via vLLM KV transfer Feb 12, 2026

@lishunyang12 lishunyang12 left a comment


WIP design feedback -- the PD disaggregation idea is sound but the implementation has some structural issues worth sorting out before polish.

```python
# expects bare list[int]. Normalise eagerly so we don't hit
# "tuple is not subscriptable" errors later.
req_id = getattr(request, "request_id", None)
if req_id and hasattr(self, "_reqs_need_send"):
```

@lishunyang12 lishunyang12 Feb 21, 2026


This skips calling super().add_new_req(), which means any future logic added to the base method would be silently missed. Would it be possible to call super() and then apply the PD-specific modifications on top?

```python
def add_new_req(
    self,
    request_id: str,
    local_block_ids: list[int],
```

@lishunyang12 lishunyang12 Feb 21, 2026


group_kv_pull mutates _reqs_need_recv in-place, which could be tricky to reason about if the base class isn't expecting the modifications. Would a copy-and-return pattern be safer here?

```python
Raises:
    ValueError: If multiple PD pairs are detected (not supported).
"""
pd_pairs: list[tuple[int, int]] = []
```

This is O(n*m) over all stages for PD detection. Fine for 4 stages, but the nested getattr(..., []) with both index-based and id-based matching makes the logic surprisingly hard to follow. Would be clearer as a single pass collecting {stage_id: index} maps first, then matching decode engine_input_source against that map.

"only a single PD pair per pipeline is supported"
)
return pd_pairs[0] if pd_pairs else None


_kv_cfg_to_dict and _normalize_kv_transfer_params are nearly identical 30-line methods that try dict/dataclass/pydantic/omegaconf/vars. They only differ in their fallback return (empty dict vs None). Factor this into a single _to_dict(obj, default=None) utility -- right now every reader has to mentally diff these two methods to see if they are actually different.

```python
while self.llm_engine.has_unfinished_requests():
    step_outputs = self.llm_engine.step()
    _step += 1
    _finished_this_step = False
```

@lishunyang12 lishunyang12 Feb 21, 2026


All the logger.warning calls here would fire on every step in production, which could get pretty noisy. Would it make sense to use DEBUG level instead for these diagnostic messages?

```python
if not pending:
    logger.warning(
        "[OmniLLM][KV-DIAG] flush: _reqs_need_send is empty — "
        "request_finished() likely did NOT re-add the entry "
```

_flush_kv_connector_sends reaches deep into vLLM internals: engine_core.scheduler.connector.connector_scheduler._reqs_need_send, then fabricates a SchedulerOutput.make_empty() and calls model_executor.execute_model() directly. This is quite brittle -- any upstream refactor (rename, restructure) silently breaks it. Worth a comment explaining why this cannot be done through the public engine API, and consider pinning a minimum vLLM version.

"do_remote_decode": False,
"do_remote_prefill": True,
"transfer_id": f"xfer-{request_id}",
}

The async PD routing builds decode_kv_params from scratch without merging any pre-existing user-provided decode-side KV fields first (unlike the sync path in omni.py which calls _normalize_kv_transfer_params on existing params). This means the two code paths have subtly different merge semantics -- worth unifying into a shared helper.

```
In PD disaggregation the decode engine only produces embeddings for the
tokens it actually computed. The prefill engine has embeddings for the
full prompt. We concatenate them, dynamically computing any overlap::
```


@lishunyang12 lishunyang12 Feb 21, 2026


_merge_pd_embeddings uses hard-coded layer indices "0" and "24" — I think this could silently break for any model with a different layer count. Would it make sense to derive these from the model config?

```python
# don't have corresponding embeddings (e.g. in PD disaggregation).
target_len = thinker_result_ids.shape[-1]
if thinker_embed.shape[0] < target_len:
    pad_len = target_len - thinker_embed.shape[0]
```

@lishunyang12 lishunyang12 Feb 21, 2026


Zero-padding thinker_embed means the talker would attend to meaningless (all-zero) embeddings for those positions. Is there a masking mechanism downstream that handles this, or could it affect output quality?

```python
_np = getattr(_kv_meta, "reqs_not_processed", None)
logger.warning(
    "[GPUARModelRunner][KV-DIAG] no_forward path: "
    "reqs_to_send=%d %s, reqs_to_recv=%d %s, "
```

@lishunyang12 lishunyang12 Feb 21, 2026


Same note as in omni_llm.py — the [KV-DIAG] logging is all at WARNING level, which would fire on every step in production. Would it make sense to drop these to DEBUG?

@ahengljh

> WIP design feedback -- the PD disaggregation idea is sound but the implementation has some structural issues worth sorting out before polish.

Thank you for the comments; I'll work on them soon.

@hsliuustc0106

@vllm-omni-reviewer

@ahengljh ahengljh force-pushed the feat/pd-disaggregation branch from b315e6b to 606e7cf Compare February 25, 2026 08:04
ahengljh added a commit to ahengljh/vllm-omni that referenced this pull request Feb 27, 2026
…iew comments

- Remove non-PD files: gpu_ar_model_runner.py (debug logging only),
  omni_ar_scheduler.py and omni_generation_scheduler.py (general compat
  shims, not PD-specific), pd_server_patch_guide.md (superseded by
  monkey_patch.py)
- Downgrade all KV-DIAG logging from WARNING to DEBUG (omni_llm.py,
  omni_stage.py)
- Strip verbose per-step/per-batch diagnostic scaffolding from
  omni_llm.py and omni_stage.py
- patched_mooncake_connector: call super().add_new_req() instead of
  skipping; use copy-and-restore pattern in group_kv_pull
- omni.py: refactor _detect_pd_separation to single-pass; deduplicate
  _kv_cfg_to_dict/_normalize_kv_transfer_params into _to_dict()
- async_omni.py: unify PD routing merge semantics with sync path
- qwen3_omni stage_input_processors: replace hardcoded "0"/"24" layer
  keys with named constants
- qwen3_omni model: document zero-padding safety for PD disaggregation
- omni_llm: add comment explaining why _flush_kv_connector_sends
  reaches into vLLM internals

PR scope reduced from 15 to 11 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@hsliuustc0106 hsliuustc0106 left a comment


Summary

This PR implements Prefill-Decode (PD) disaggregation for vLLM-Omni. While the feature is architecturally sound, the implementation has several critical issues that need to be addressed.

Critical Issues:

  • Memory leak: _pd_kv_params_by_req never cleaned up on request failure
  • Silent failures in config parsing with empty dict fallbacks
  • Race conditions in state management despite locks
  • Fragile monkey-patching of vLLM internals
  • Hardcoded defaults (bootstrap port 25201) without documentation

Moderate Issues:

  • Complex state management spread across multiple dictionaries
  • Inconsistent error handling (some raise, some return None)
  • Missing validation for edge cases
  • No version compatibility checks for vLLM

Minor Issues:

  • Debug-level logging for important events
  • Large PR mixing feature + tests makes review difficult

Recommendation: Request changes - address memory leak and silent failures before merge.

```python
from pprint import pformat
from typing import Any, Literal, overload

import huggingface_hub
```

Critical: Memory leak in request tracking

_pd_kv_params_by_req stores KV params per request but is never cleaned up when requests fail or complete normally. This will cause unbounded memory growth in long-running servers.

Looking at the code:

  • Params are stored in _run_generation when extracting from engine outputs
  • They're popped in _pop_pd_kv_params but only when routing to decode stage
  • If a request fails before reaching decode, or if decode stage errors out, the entry remains forever

Required fix:
Add cleanup in the request finalization path:

```python
# In _run_generation, after request completes/fails:
finally:
    self._drop_pd_kv_params(req_id)
```

```
@@ -181,6 +183,22 @@ def __init__(self, model: str, **kwargs: Any) -> None:
logger.info(f"Initializing stages for model: {model}")
self._initialize_stages(model, kwargs)
```


Critical: Silent failures in config parsing

_kv_cfg_to_dict returns an empty dict {} when parsing fails:

```python
def _kv_cfg_to_dict(self, kv_cfg: Any) -> dict[str, Any]:
    return self._to_dict(kv_cfg, default={}) or {}
```

This means if the config is malformed, you get {} instead of an error. Later code checks if not kv_cfg_dict but by then you've lost context about what failed.

Problems:

  1. _get_pd_connector_info silently returns None if parsing fails
  2. _validate_pd_separation_config raises "could not be parsed" but doesn't say why
  3. Users get cryptic errors far from the actual problem

Required fix:
Fail fast with clear errors:

```python
def _kv_cfg_to_dict(self, kv_cfg: Any) -> dict[str, Any]:
    result = self._to_dict(kv_cfg, default=None)
    if result is None:
        raise ValueError(f"Failed to parse kv_transfer_config: {type(kv_cfg).__name__}")
    return result
```

```python
self._pd_connector_info: dict[str, Any] | None = None
self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {}
self._pd_kv_params_lock = threading.Lock()
if self._pd_separation_pair is not None:
```

Issue: Race condition in KV params management

The lock protects individual dict operations but not the full read-modify-write sequence:

```python
# Thread 1: Store params
with self._pd_kv_params_lock:
    self._pd_kv_params_by_req[req_id] = kv_params

# Thread 2: Pop params (different request path)
with self._pd_kv_params_lock:
    stored = self._pd_kv_params_by_req.pop(req_id, None)
```

If thread 1 stores params after thread 2 checks, thread 2 gets None but params are still stored (memory leak).

Fix: Use a proper request lifecycle state machine or ensure single-threaded request processing.

```python
logger.info(
    "[%s] PD disaggregation detected: prefill=stage-%d, decode=stage-%d",
    self._name,
    p_id,
```

Issue: Hardcoded default with no documentation

```python
if bootstrap_port is None:
    bootstrap_port = 25201
```

Why 25201? What if it's in use? No documentation in code or config file.

Required:

  1. Document this default in the YAML config file
  2. Add a constant: DEFAULT_MOONCAKE_BOOTSTRAP_PORT = 25201
  3. Explain port allocation strategy (25201 for prefill, 25202 for decode?)

```
@@ -0,0 +1,91 @@
"""Monkey-patch vLLM's native ``MooncakeConnector`` with the patched version
that fixes request-ID mismatch in PD disaggregation.
```


Critical: No version compatibility checks

Monkey-patching vLLM's MooncakeConnector with no version checks is dangerous. If vLLM changes the connector API, this will break silently or cause subtle bugs.

Required:

  1. Add a version check at module import:

```python
import vllm
if not hasattr(vllm, '__version__'):
    raise RuntimeError("Cannot determine vLLM version for monkey-patch compatibility")
# Check against known compatible versions
```

  2. Document which vLLM versions this was tested with
  3. Consider upstreaming this to vLLM instead of monkey-patching

```python
)

def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] | None:
    if cache_backend == "cache_dit":
```

Issue: Missing validation for stage configuration

The code doesn't validate:

  1. Both PD stages have same tensor_parallel_size (mentioned in PR description as required)
  2. Decode stage's engine_input_source actually points to prefill stage
  3. No other stages depend on the prefill stage (would break routing)

Add validation:

```python
def _validate_pd_separation_config(self):
    # ... existing checks ...

    # Check TP sizes match
    p_tp = getattr(p_stage.engine_args, 'tensor_parallel_size', 1)
    d_tp = getattr(d_stage.engine_args, 'tensor_parallel_size', 1)
    if p_tp != d_tp:
        raise ValueError(
            f"PD stages must have matching tensor_parallel_size: "
            f"prefill={p_tp}, decode={d_tp}"
        )
```

```python
self._pd_separation_pair is not None
and len(sampling_params_list) == len(self.stage_list) - 1
):
    p_id, d_id = self._pd_separation_pair
```

Issue: Silent behavior change

Auto-duplicating sampling params is convenient but changes behavior silently. If a user provides N-1 params by mistake (not knowing about PD), they get unexpected duplication.

Suggestion:
Log at WARNING level when this happens:

```python
logger.warning(
    "[%s] Detected %d sampling params for %d stages (PD mode). "
    "Auto-duplicating thinker params for decode stage %d. "
    "To avoid this warning, provide %d sampling params explicitly.",
    self._name, len(sampling_params_list), len(self.stage_list),
    d_id, len(self.stage_list)
)
```

```python
# Safety note: Zero-padded positions are safe because the talker's
# ChatML-segment loop (below) only slices embeddings within
# im_start_index boundaries. The padded tail falls outside the last
# assistant segment and is never attended to. Additionally, the
```

Issue: Magic number without validation

```python
target_len = thinker_result_ids.shape[-1]
# ... zero-padding logic ...
```

The comment says padding is safe, but there's no assertion to catch unexpected cases. What if target_len is 10,000 tokens larger due to a bug?

Add safety check:

```python
if target_len > thinker_embed.shape[1]:
    pad_len = target_len - thinker_embed.shape[1]
    if pad_len > 512:  # Reasonable threshold
        raise ValueError(
            f"Unexpectedly large padding required: {pad_len} tokens. "
            f"This may indicate a bug in PD disaggregation."
        )
```


```
Tests the PD detection, validation, config parsing, sampling param
preparation, and routing logic added by the PD disaggregation feature
(issue #1188). All tests run without GPU by using the same mocking
```

Question: Test coverage gaps

The tests are comprehensive for happy paths but missing:

  1. What happens when prefill stage crashes mid-request?
  2. What happens when decode stage can't connect to prefill?
  3. What happens when KV transfer times out?
  4. What happens with concurrent requests (race conditions)?
  5. Memory leak test (does _pd_kv_params_by_req grow unbounded)?

These failure modes are critical for production use.
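Point 5 could be covered by a small unit test along these lines. The attribute name `_pd_kv_params_by_req` comes from the review discussion; the orchestrator here is a minimal stand-in, not vllm-omni's real class:

```python
# Leak-test sketch: after a request fails before reaching decode, its
# entry must not linger in _pd_kv_params_by_req.
class FakeOrchestrator:
    def __init__(self):
        self._pd_kv_params_by_req = {}

    def _drop_pd_kv_params(self, req_id):
        self._pd_kv_params_by_req.pop(req_id, None)

    def run_request(self, req_id, fail=False):
        self._pd_kv_params_by_req[req_id] = {"do_remote_decode": True}
        try:
            if fail:
                raise RuntimeError("simulated prefill failure")
        finally:
            # defense-in-depth cleanup on every exit path
            self._drop_pd_kv_params(req_id)

orch = FakeOrchestrator()
try:
    orch.run_request("req-1", fail=True)
except RuntimeError:
    pass
```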


@lishunyang12 lishunyang12 left a comment


Thanks for addressing the earlier feedback -- the single-pass detection rewrite, _to_dict dedup, logging downgrade, and super().add_new_req() call all look correct now. A few remaining items:

```python
return pd_pairs[0] if pd_pairs else None

@staticmethod
def _to_dict(obj: Any, default: Any = None) -> dict[str, Any] | None:
```

_to_dict silently swallows every conversion failure and falls back to default. But _validate_pd_separation_config calls _kv_cfg_to_dict (which calls _to_dict(kv_cfg, default={})) and then checks if not cfg_dict. If kv_cfg is a broken object that partially converts, you get an incomplete dict that passes validation but has missing keys. Consider raising on conversion failure inside _validate_pd_separation_config.

Comment on lines 607 to 615
```python
# PD safeguard: store kv_transfer_params as a plain-dict backup in the
# payload so it definitely survives pickle even if the msgspec.Struct
# extra_args field is silently dropped.
if "_kv_transfer_params" not in payload:
    sp = payload.get("sampling_params")
    if sp is not None and hasattr(sp, "extra_args") and sp.extra_args:
        kv_tp = sp.extra_args.get("kv_transfer_params")
        if kv_tp is not None:
            payload["_kv_transfer_params"] = dict(kv_tp)
```

The _kv_transfer_params backup-and-restore pattern is a workaround for a msgspec serialization issue. Is this actually reproducible, or was it a hypothesis during debugging? If confirmed, worth opening a vLLM issue. If not, this adds complexity for a speculative fix.

```python
return result

# Preserve the original module name for isinstance checks in vLLM
PatchedMooncakeConnector.__qualname__ = "MooncakeConnector"
```

Use the original's qualname instead of hardcoding "MooncakeConnector" -- if the upstream class is ever renamed, this stays correct.

Suggested change:

```diff
-PatchedMooncakeConnector.__qualname__ = "MooncakeConnector"
+PatchedMooncakeConnector.__qualname__ = _OriginalMooncakeConnector.__qualname__
```

```python
merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0)
merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0)

logger.info(
```

The logger.info inside _merge_pd_embeddings fires on every request in PD mode. Should be logger.debug.

@ahengljh ahengljh force-pushed the feat/pd-disaggregation branch from bddce52 to df087f3 Compare March 2, 2026 02:44
ahengljh added a commit to ahengljh/vllm-omni that referenced this pull request Mar 2, 2026
…iew comments

ahengljh added a commit to ahengljh/vllm-omni that referenced this pull request Mar 2, 2026
…ests, e2e

- Neutralize stop/stop_token_ids in prefill sampling params to ensure
  finish_reason='length' (prevents MooncakeConnector KV transfer cancel)
- Add _DEFAULT_MOONCAKE_BOOTSTRAP_PORT named constant
- Add tensor_parallel_size validation in PD config check
- Improve error messages with type info for kv_transfer_config parsing
- Add defense-in-depth cleanup of _pd_kv_params_by_req after generation
- Upgrade auto-duplication log to WARNING with suppression hint
- Downgrade per-request PD routing/trace logs from INFO to DEBUG
- Add vLLM version compatibility warning in monkey_patch.py
- Use dynamic __qualname__ from original MooncakeConnector
- Add padding threshold warning (512 tokens) in model zero-padding
- Add clarifying comments on threading model, merge order, save-patch-restore
- Add unit tests: stop neutralization, failure/leak cleanup, TP validation
- Add PD e2e tests for both text and audio modalities (offline + online)
- Add PD CI stage config with load_format: dummy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ahengljh ahengljh force-pushed the feat/pd-disaggregation branch from 217d30d to 68e0f9e Compare March 2, 2026 06:28
Split the thinker stage into separate prefill and decode instances that
communicate via vLLM's native KV transfer (MooncakeConnector). The prefill
engine processes prompts and saves KV cache; the decode engine loads the
cache and generates tokens.

Key changes:
- PD detection, validation, and routing in OmniBase and AsyncOmni
- Prefill sampling params: max_tokens=1, neutralize stop conditions
- Patched MooncakeConnector with remote_request_id for cross-engine KV lookup
- Monkey-patch infrastructure with vLLM version compatibility check
- Embedding merge (prefill + decode) in thinker2talker stage processor
- Zero-padding safety with threshold warning in talker model
- Defense-in-depth cleanup of KV params after generation
- Unit tests for PD detection, validation, routing, stop neutralization,
  failure modes, memory leak prevention, and TP validation
- E2E tests for both text and audio modalities (offline + online)
- PD CI stage config with load_format: dummy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ahengljh ahengljh force-pushed the feat/pd-disaggregation branch from 68e0f9e to 0483a26 Compare March 2, 2026 06:44
ahengljh and others added 2 commits March 2, 2026 17:11
… flow

Update test mocking infrastructure to align with the refactored OmniBase
initialization chain:
- Mock load_and_resolve_stage_configs instead of removed load_stage_configs_from_model
- Mock omni_snapshot_download, initialize_orchestrator_connectors, _start_stages,
  _wait_for_stages_ready, and try_send_via_connector for full init bypass
- Replace _FakeOrchestratorMetrics with _FakeOrchestratorAggregator matching
  the current class interface (new methods, updated signatures)
- Add missing final_output/final_output_type attrs in test_stage_payload_includes_pd_flags

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The three failure mode tests (error_path, completion, multiple_requests)
hang because _run_generation's error handler calls _drop_pd_kv_params
but does not increment completed_requests, causing an infinite loop.
Remove for now until the production error-handling path is fixed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
