Merged
Commits
32 commits
6cde323
cfg multi-stage
princepride Feb 21, 2026
4112f54
_forward_parent_with_cfg saved in parent_result['engine_outputs'] && …
princepride Feb 22, 2026
428e774
only check whether parent have complete
princepride Feb 22, 2026
a66e55f
change to deepcopy
princepride Feb 23, 2026
a32a813
update bagel e2e test
princepride Feb 23, 2026
883549d
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 23, 2026
321e5a4
add docs
princepride Feb 23, 2026
3ae335f
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 23, 2026
b233bb5
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 24, 2026
f6b0195
Remove unused completed_requests param from _forward_parent_with_cfg
princepride Feb 24, 2026
0156ff7
Build reverse index for O(1) companion-to-parent lookup
princepride Feb 24, 2026
18ccae3
Add single-threaded assumption note for source_outputs_override swap
princepride Feb 24, 2026
c0a1eb7
Remove unused is_cfg_companion_request and get_parent_request_id
princepride Feb 24, 2026
cea6766
Clarify why cfg_kv_collect_func is resolved on the parent side
princepride Feb 24, 2026
9dba25f
Fix _get_negative_prompt to treat empty string as absent
princepride Feb 24, 2026
6cca43d
Add comment explaining max_batch_size assumption in bagel.yaml
princepride Feb 24, 2026
9b9f06e
fix pre-commit
princepride Feb 24, 2026
a3a10f2
remove omni cfg_kv_collect_func
princepride Feb 24, 2026
06f2165
move load_func_from_config to stage_utils
princepride Feb 24, 2026
5ff56f6
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 24, 2026
3f90cdd
wrap cfg processor
princepride Feb 26, 2026
a694646
wrap cfg processor
princepride Feb 26, 2026
30b5db4
wrap cfg processor
princepride Feb 26, 2026
b50abab
wrap cfg processor
princepride Feb 26, 2026
a6fc71d
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 26, 2026
ebbc36f
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 27, 2026
6a04e70
correct process completed_requests if transfer data to next stage failed
princepride Feb 28, 2026
c2eb553
idiomatic return value
princepride Feb 28, 2026
8ccf529
if not sent_via_connector, throw error
princepride Feb 28, 2026
9c688d6
add test_cfg_companion_tracker and fix bagel exception handlers
princepride Mar 1, 2026
46ef8c5
Merge branch 'main' into feature/cfg-multi-stage
princepride Mar 1, 2026
753a19a
test kv_transfer_manager methods expand
princepride Mar 2, 2026
8 changes: 8 additions & 0 deletions docs/configuration/stage_configs.md
@@ -135,6 +135,14 @@ Each stage in the `stage_args` list contains the following configuration options

A unique identifier for each stage in the multi-stage pipeline. Stages are numbered sequentially starting from 0, and this ID is used to reference stages in inter-stage dependencies (e.g., `engine_input_source`).

### `prompt_expand_func` (Optional)

A Python import path for an optional hook on the LLM stage (stage 0) that expands a single incoming prompt into multiple prompts. Its primary use is multimodal Classifier-Free Guidance (CFG): the hook generates the necessary companion requests (such as a negative text prompt) and tags each with an internal role (e.g., `cfg_text`), so the upstream LLM produces the contextual hidden states for both the conditional and unconditional branches in the same pass.
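Illustrative only: a minimal sketch of the hook's shape, inferred from the test double in `test_cfg_companion_tracker.py`; the `CompanionPrompt` container and the `negative_prompt` lookup are assumptions, not the shipped implementation.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class CompanionPrompt:
    # Shape inferred from the tracker tests: each expansion carries the new
    # prompt payload, an internal role tag, and a request-id suffix that is
    # appended to the parent request id.
    prompt: dict[str, Any]
    role: str
    request_id_suffix: str


def expand_cfg_prompts(prompt: Any, stage0_sampling_params: Any) -> list[CompanionPrompt]:
    """Stage-0 hook: derive CFG companion prompts from one incoming prompt.

    Returning an empty list means the prompt needs no companions.
    """
    negative = ""
    if isinstance(prompt, dict):
        # Hypothetical key: where a user-supplied negative prompt might live.
        negative = prompt.get("negative_prompt", "")
    return [
        CompanionPrompt(
            prompt={"prompt": negative},
            role="cfg_text",
            request_id_suffix="__cfg_text",
        )
    ]
```

The engine concatenates `request_id_suffix` onto the parent request id (e.g., `req1` + `__cfg_text` → `req1__cfg_text`), which is the id the downstream stage later uses to locate the companion KV cache.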

### `cfg_kv_collect_func` (Optional)

A Python import path for an optional hook on downstream diffusion stages (stage 1+) that collects and maps the KV caches transferred from the companion requests created by `prompt_expand_func`. The hook attaches the collected caches to the request (e.g., as `cfg_text_past_key_values` and `cfg_text_kv_metadata`), letting the diffusion runtime apply CFG without re-evaluating the text branches on the DiT workers.

### `runtime`

Configuration for disaggregated execution of the stage, controlling how the stage is deployed and executed.
7 changes: 7 additions & 0 deletions docs/design/architecture_overview.md
@@ -92,6 +92,13 @@ The framework achieves high performance through several optimization techniques:
* **Quantization:** Supports various quantization implementations including FP8 and AWQ.
* **FusedOps:** Allows for custom and third-party integration.

### Classifier-Free Guidance (CFG) Companion Flow

vLLM-Omni models Classifier-Free Guidance (CFG) across disaggregated multi-stage pipelines through a "companion request" mechanism, which avoids recomputing textual/multimodal context on downstream stages:
1. **Prompt Expansion:** In the autoregressive (AR) stage, a model-specific `prompt_expand_func` hook intercepts each incoming generation prompt and pairs it with companion prompts (e.g., a default negative prompt), tagging each companion with an internal role such as `cfg_text`.
2. **Synchronized KV Cache Transfer:** The AR stage evaluates the primary and companion sequences concurrently. The `OmniConnector` tracks the parent/companion dependency and transfers both the conditional and unconditional KV caches across the stage boundary via shared memory or network protocols.
3. **KV Cache Collection & Injection:** On the downstream diffusion (DiT) stage, the configured `cfg_kv_collect_func` fetches the mapped companion caches (e.g., `cfg_text_past_key_values`) and binds them to the primary request, so the DiT engine can run CFG over the conditional and unconditional branches without re-encoding either one.
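The two hooks are wired per stage in the YAML stage config; the fragment below is condensed from the Bagel configs changed in this PR:

```yaml
stage_args:
  - stage_id: 0
    stage_type: llm
    # Stage 0 fans each prompt out into CFG companion requests.
    prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
  - stage_id: 1
    stage_type: diffusion
    # Stage 1 gathers the companion KV caches before running the DiT.
    cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
```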

### Flexibility and Usability

vLLM-Omni is designed to be flexible and straightforward for users:
9 changes: 7 additions & 2 deletions examples/offline_inference/bagel/end2end.py
@@ -49,7 +49,7 @@ def parse_args():
parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
parser.add_argument("--cfg-img-scale", type=float, default=1.5, help="Image CFG scale (default: 1.5)")
parser.add_argument(
"--negative-prompt", type=str, default=None, help="Negative prompt (not yet supported, reserved for future)"
"--negative-prompt", type=str, default=None, help="Negative prompt for CFG (default: empty prompt)"
)

args = parser.parse_args()
@@ -162,6 +162,8 @@ def main():
# text2img
final_prompt_text = f"<|im_start|>{p}<|im_end|>"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
formatted_prompts.append(prompt_dict)

params_list = omni.default_sampling_params_list
@@ -170,10 +172,13 @@
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.extra_args = { # type: ignore
extra = {
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore

omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

5 changes: 4 additions & 1 deletion tests/diffusion/test_diffusion_model_runner.py
@@ -56,7 +56,10 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar
enable_cache_dit_summary=enable_cache_dit_summary,
parallel_config=SimpleNamespace(use_hsdp=False),
)
runner.kv_transfer_manager = SimpleNamespace(receive_kv_cache=lambda req, target_device: None)
runner.kv_transfer_manager = SimpleNamespace(
receive_kv_cache=lambda req, target_device=None: None,
receive_multi_kv_cache=lambda req, cfg_kv_collect_func=None, target_device=None: None,
)
return runner


@@ -4,6 +4,7 @@
stage_args:
- stage_id: 0
stage_type: llm
prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
runtime:
devices: "0"
max_batch_size: 1
@@ -39,6 +40,7 @@ stage_args:
to_stage_1: mooncake_connector
- stage_id: 1
stage_type: diffusion
cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
runtime:
devices: "0"
max_batch_size: 1
@@ -4,6 +4,7 @@
stage_args:
- stage_id: 0
stage_type: llm
prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
runtime:
devices: "0"
max_batch_size: 1
@@ -38,6 +39,7 @@ stage_args:

- stage_id: 1
stage_type: diffusion
cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
runtime:
devices: "0"
max_batch_size: 1
24 changes: 14 additions & 10 deletions tests/e2e/offline_inference/test_bagel_text2img.py
@@ -37,16 +37,16 @@
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
{"position": (100, 100), "rgb": (68, 107, 134)},
{"position": (400, 50), "rgb": (95, 139, 166)},
{"position": (700, 100), "rgb": (99, 122, 151)},
{"position": (150, 400), "rgb": (111, 125, 153)},
{"position": (512, 512), "rgb": (97, 107, 131)},
{"position": (700, 400), "rgb": (48, 64, 98)},
{"position": (100, 700), "rgb": (79, 63, 84)},
{"position": (400, 700), "rgb": (40, 58, 79)},
{"position": (700, 700), "rgb": (60, 75, 103)},
{"position": (256, 256), "rgb": (97, 128, 156)},
{"position": (100, 100), "rgb": (49, 96, 134)},
{"position": (400, 50), "rgb": (63, 127, 167)},
{"position": (700, 100), "rgb": (70, 101, 141)},
{"position": (150, 400), "rgb": (115, 90, 150)},
{"position": (512, 512), "rgb": (98, 86, 119)},
{"position": (700, 400), "rgb": (29, 42, 91)},
{"position": (100, 700), "rgb": (47, 50, 88)},
{"position": (400, 700), "rgb": (36, 52, 91)},
{"position": (700, 700), "rgb": (45, 58, 99)},
{"position": (256, 256), "rgb": (62, 94, 135)},
]

# Maximum allowed difference per color channel
@@ -80,6 +80,10 @@ def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_st
params_list[0].max_tokens = max_tokens # type: ignore
if len(params_list) > 1:
params_list[1].num_inference_steps = num_inference_steps # type: ignore
params_list[1].extra_args = { # type: ignore
"cfg_text_scale": 4.0,
"cfg_img_scale": 1.5,
}
return params_list


114 changes: 114 additions & 0 deletions tests/entrypoints/test_cfg_companion_tracker.py
@@ -0,0 +1,114 @@
import time
from types import SimpleNamespace

import pytest

from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def dummy_expand_func(prompt, sp0):
if prompt == "expand_me":
return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")]
return []


@pytest.fixture
def tracker():
sp0 = SimpleNamespace()
return CfgCompanionTracker(prompt_expand_func=dummy_expand_func, stage0_sampling_params=sp0, timeout_s=0.1)


def test_companion_tracker_initialization(tracker):
assert not tracker.is_active
assert tracker.num_companions == 0


def test_expand_prompts_registers_companions(tracker):
request_id_to_prompt = {"req1": "expand_me", "req2": "do_not_expand"}

pairs = tracker.expand_prompts(request_id_to_prompt)

assert len(pairs) == 1
companion_id, prompt = pairs[0]
assert companion_id == "req1__cfg_text"
assert prompt == {"prompt": "neg"}

assert tracker.is_active
assert tracker.num_companions == 1
assert tracker.is_companion("req1__cfg_text")
assert not tracker.is_companion("req2__cfg_text")
assert tracker.has_companions("req1")
assert not tracker.has_companions("req2")

comp_map = tracker.get_companion_request_ids("req1")
assert comp_map == {"cfg_text": "req1__cfg_text"}


def test_companion_lifecycle_success(tracker):
request_id_to_prompt = {"req1": "expand_me"}
tracker.expand_prompts(request_id_to_prompt)

# Defer parent
engine_outputs = {"out": 123}
tracker.defer_parent("req1", engine_outputs, stage_id=0)

# Initially not done
assert not tracker.all_companions_done("req1")

# Companion completes
parent_id = tracker.on_companion_completed("req1__cfg_text")

# Parent should be returned since all companions are done and it is pending
assert parent_id == "req1"
assert tracker.all_companions_done("req1")

# Pop pending parent
popped = tracker.pop_pending_parent("req1")
assert popped is not None
assert popped["engine_outputs"] == engine_outputs
assert popped["stage_id"] == 0


def test_companion_lifecycle_failure(tracker):
request_id_to_prompt = {"req1": "expand_me"}
tracker.expand_prompts(request_id_to_prompt)

tracker.defer_parent("req1", {"out": 123}, stage_id=0)

# Companion fails
parent_id, aborted = tracker.on_companion_error("req1__cfg_text")

assert parent_id == "req1"
assert aborted is True
assert tracker.is_parent_failed("req1")

# Parent should be removed from pending list
assert tracker.pop_pending_parent("req1") is None

# Consume failure
tracker.consume_parent_failure("req1")
assert not tracker.is_parent_failed("req1")


def test_companion_lifecycle_timeout(tracker):
request_id_to_prompt = {"req1": "expand_me"}
tracker.expand_prompts(request_id_to_prompt)

tracker.defer_parent("req1", {"out": 123}, stage_id=0)

# Initially no timeouts
timeouts = tracker.check_timeouts()
assert len(timeouts) == 0

# Wait for timeout
time.sleep(0.15)

# Check timeouts again
timeouts = tracker.check_timeouts()
assert len(timeouts) == 1
assert timeouts[0] == "req1"

# Should be removed from pending
assert tracker.pop_pending_parent("req1") is None
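The tests above pin down the tracker's public surface. A minimal in-memory sketch that satisfies them (the shipped `CfgCompanionTracker` may differ internally; this is illustrative only):

```python
import time


class CfgCompanionTracker:
    """Sketch: track CFG companion requests and defer parents until they finish."""

    def __init__(self, prompt_expand_func, stage0_sampling_params, timeout_s):
        self._expand = prompt_expand_func
        self._sp0 = stage0_sampling_params
        self._timeout_s = timeout_s
        self._companions = {}      # companion_id -> parent_id
        self._pending = {}         # parent_id -> set of unfinished companion ids
        self._roles = {}           # parent_id -> {role: companion_id}
        self._deferred = {}        # parent_id -> saved state + timeout deadline
        self._failed = set()

    @property
    def is_active(self):
        return bool(self._companions)

    @property
    def num_companions(self):
        return len(self._companions)

    def expand_prompts(self, request_id_to_prompt):
        pairs = []
        for parent_id, prompt in request_id_to_prompt.items():
            for comp in self._expand(prompt, self._sp0):
                cid = parent_id + comp.request_id_suffix
                self._companions[cid] = parent_id
                self._pending.setdefault(parent_id, set()).add(cid)
                self._roles.setdefault(parent_id, {})[comp.role] = cid
                pairs.append((cid, comp.prompt))
        return pairs

    def is_companion(self, request_id):
        return request_id in self._companions

    def has_companions(self, parent_id):
        return parent_id in self._roles

    def get_companion_request_ids(self, parent_id):
        return dict(self._roles.get(parent_id, {}))

    def defer_parent(self, parent_id, engine_outputs, stage_id):
        self._deferred[parent_id] = {
            "engine_outputs": engine_outputs,
            "stage_id": stage_id,
            "deadline": time.monotonic() + self._timeout_s,
        }

    def all_companions_done(self, parent_id):
        return not self._pending.get(parent_id)

    def on_companion_completed(self, companion_id):
        parent_id = self._companions[companion_id]
        self._pending.get(parent_id, set()).discard(companion_id)
        # Release the parent only once every companion finished and it was deferred.
        if self.all_companions_done(parent_id) and parent_id in self._deferred:
            return parent_id
        return None

    def pop_pending_parent(self, parent_id):
        entry = self._deferred.pop(parent_id, None)
        if entry is not None:
            entry.pop("deadline", None)
        return entry

    def on_companion_error(self, companion_id):
        parent_id = self._companions[companion_id]
        aborted = self._deferred.pop(parent_id, None) is not None
        self._failed.add(parent_id)
        return parent_id, aborted

    def is_parent_failed(self, parent_id):
        return parent_id in self._failed

    def consume_parent_failure(self, parent_id):
        self._failed.discard(parent_id)

    def check_timeouts(self):
        now = time.monotonic()
        expired = [p for p, e in self._deferred.items() if now >= e["deadline"]]
        for p in expired:
            self._deferred.pop(p, None)
        return expired
```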
3 changes: 3 additions & 0 deletions vllm_omni/diffusion/data.py
@@ -442,6 +442,9 @@ class OmniDiffusionConfig:
# Omni configuration (injected from stage config)
omni_kv_config: dict[str, Any] = field(default_factory=dict)

# Model-specific function for collecting CFG KV caches (set at runtime)
cfg_kv_collect_func: Any | None = None

# Quantization settings
# Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs)
quantization: str | None = None
44 changes: 34 additions & 10 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -327,16 +327,40 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
gen_context["past_key_values"] = injected_kv
seq_len = injected_kv.key_cache[0].shape[0]
gen_context["kv_lens"] = [seq_len]
gen_context["ropes"] = [seq_len]

# Disable CFG: single KV cache cannot support 3-branch CFG
logger.warning("CFG is disabled when using injected KV Cache")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)
if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
gen_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
else:
gen_context["ropes"] = [seq_len]

cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
if cfg_text_kv is not None:
logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0]
cfg_text_context["past_key_values"] = cfg_text_kv
cfg_text_context["kv_lens"] = [cfg_text_seq_len]
cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
if cfg_text_metadata and "ropes" in cfg_text_metadata:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]

cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None)
if cfg_img_metadata and "ropes" in cfg_img_metadata:
cfg_img_context["ropes"] = cfg_img_metadata["ropes"]
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
else:
logger.warning("CFG is disabled: only single KV cache available")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)

else:
image_input = (
6 changes: 5 additions & 1 deletion vllm_omni/diffusion/worker/diffusion_model_runner.py
@@ -191,7 +191,11 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput:
grad_context = torch.no_grad() if use_hsdp else torch.inference_mode()
with grad_context:
# The manager handles the check for need_recv_cache internally
self.kv_transfer_manager.receive_kv_cache(req, target_device=getattr(self.pipeline, "device", None))
self.kv_transfer_manager.receive_multi_kv_cache(
req,
cfg_kv_collect_func=getattr(self.od_config, "cfg_kv_collect_func", None),
target_device=getattr(self.pipeline, "device", None),
)

if req.sampling_params.generator is None and req.sampling_params.seed is not None:
if req.sampling_params.generator_device is not None:
47 changes: 47 additions & 0 deletions vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
@@ -522,3 +522,50 @@ def receive_kv_cache(self, req: Any, target_device: torch.device | None = None)
self.apply_kv_cache_to_request(req, data)
return True
return False

def receive_multi_kv_cache(
self,
req: Any,
cfg_kv_collect_func: Callable | None = None,
target_device: torch.device | None = None,
) -> bool:
"""Receive primary KV cache and optional CFG companion KV caches.

First receives the primary KV cache (existing logic). Then, if the
request carries cfg_kv_request_ids and a model-specific
cfg_kv_collect_func is provided, calls it to fetch and attach the
companion KV caches to sampling_params.

Args:
req: Request object with request_id and sampling_params.
cfg_kv_collect_func: Model-specific function for collecting
CFG KV caches. Signature:
(request_id, cfg_request_ids, kv_transfer_manager, target_device)
-> dict[str, Any]
target_device: Device to move tensors to.

Returns:
True if primary KV cache was received successfully.
"""
primary_ok = self.receive_kv_cache(req, target_device)

cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
if cfg_ids and cfg_kv_collect_func:
request_id = getattr(req, "request_id", None) or (
req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
)
try:
cfg_kvs = cfg_kv_collect_func(
request_id,
cfg_ids,
self,
target_device,
)
if cfg_kvs and hasattr(req, "sampling_params") and req.sampling_params is not None:
for key, value in cfg_kvs.items():
setattr(req.sampling_params, key, value)
logger.info("Applied CFG KV caches: %s", list(cfg_kvs.keys()))
except Exception:
logger.exception("Failed to collect CFG KV caches for %s", request_id)

return primary_ok
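One way a model-specific collect function matching the documented `(request_id, cfg_request_ids, kv_transfer_manager, target_device)` signature could be written: reuse the manager's existing `receive_kv_cache` by wrapping each companion id in a lightweight request stub. The stub shape and the `past_key_values`/`kv_metadata` attribute names are assumptions for illustration, not the shipped Bagel implementation.

```python
from types import SimpleNamespace
from typing import Any


def collect_cfg_kv_caches(request_id, cfg_request_ids, kv_transfer_manager, target_device) -> dict[str, Any]:
    """Fetch each companion KV cache and key it by role for sampling_params.

    cfg_request_ids maps role -> companion request id,
    e.g. {"cfg_text": "req1__cfg_text"}.
    """
    collected: dict[str, Any] = {}
    for role, companion_id in cfg_request_ids.items():
        # Hypothetical stub request so the manager's normal receive path can
        # be reused per companion; the manager populates sampling_params.
        stub = SimpleNamespace(request_id=companion_id, sampling_params=SimpleNamespace())
        if kv_transfer_manager.receive_kv_cache(stub, target_device):
            sp = stub.sampling_params
            collected[f"{role}_past_key_values"] = getattr(sp, "past_key_values", None)
            collected[f"{role}_kv_metadata"] = getattr(sp, "kv_metadata", None)
    return collected
```

`receive_multi_kv_cache` then copies each returned key (e.g., `cfg_text_past_key_values`) onto the parent request's `sampling_params` via `setattr`, which is where `pipeline_bagel.py` reads them back.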