Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
6cde323
cfg multi-stage
princepride Feb 21, 2026
4112f54
_forward_parent_with_cfg saved in parent_result['engine_outputs'] && …
princepride Feb 22, 2026
428e774
only check whether the parent has completed
princepride Feb 22, 2026
a66e55f
change to deepcopy
princepride Feb 23, 2026
a32a813
update bagel e2e test
princepride Feb 23, 2026
883549d
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 23, 2026
321e5a4
add docs
princepride Feb 23, 2026
3ae335f
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 23, 2026
b233bb5
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 24, 2026
f6b0195
Remove unused completed_requests param from _forward_parent_with_cfg
princepride Feb 24, 2026
0156ff7
Build reverse index for O(1) companion-to-parent lookup
princepride Feb 24, 2026
18ccae3
Add single-threaded assumption note for source_outputs_override swap
princepride Feb 24, 2026
c0a1eb7
Remove unused is_cfg_companion_request and get_parent_request_id
princepride Feb 24, 2026
cea6766
Clarify why cfg_kv_collect_func is resolved on the parent side
princepride Feb 24, 2026
9dba25f
Fix _get_negative_prompt to treat empty string as absent
princepride Feb 24, 2026
6cca43d
Add comment explaining max_batch_size assumption in bagel.yaml
princepride Feb 24, 2026
9b9f06e
fix pre-commit
princepride Feb 24, 2026
a3a10f2
remove omni cfg_kv_collect_func
princepride Feb 24, 2026
06f2165
move load_func_from_config to stage_utils
princepride Feb 24, 2026
5ff56f6
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 24, 2026
3f90cdd
wrap cfg processor
princepride Feb 26, 2026
a694646
wrap cfg processor
princepride Feb 26, 2026
30b5db4
wrap cfg processor
princepride Feb 26, 2026
b50abab
wrap cfg processor
princepride Feb 26, 2026
a6fc71d
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 26, 2026
ebbc36f
Merge branch 'main' into feature/cfg-multi-stage
princepride Feb 27, 2026
6a04e70
correctly process completed_requests if transferring data to the next stage failed
princepride Feb 28, 2026
c2eb553
idiomatic return value
princepride Feb 28, 2026
8ccf529
if not sent_via_connector, throw error
princepride Feb 28, 2026
9c688d6
add test_cfg_companion_tracker and fix bagel exception handlers
princepride Mar 1, 2026
46ef8c5
Merge branch 'main' into feature/cfg-multi-stage
princepride Mar 1, 2026
753a19a
test kv_transfer_manager methods expand
princepride Mar 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parse_args():
parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
parser.add_argument("--cfg-img-scale", type=float, default=1.5, help="Image CFG scale (default: 1.5)")
parser.add_argument(
"--negative-prompt", type=str, default=None, help="Negative prompt (not yet supported, reserved for future)"
"--negative-prompt", type=str, default=None, help="Negative prompt for CFG (default: empty prompt)"
)

args = parser.parse_args()
Expand Down Expand Up @@ -162,6 +162,8 @@ def main():
# text2img
final_prompt_text = f"<|im_start|>{p}<|im_end|>"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
formatted_prompts.append(prompt_dict)

params_list = omni.default_sampling_params_list
Expand All @@ -170,10 +172,13 @@ def main():
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.extra_args = { # type: ignore
extra = {
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore

omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

Expand Down
3 changes: 3 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ class OmniDiffusionConfig:
# Omni configuration (injected from stage config)
omni_kv_config: dict[str, Any] = field(default_factory=dict)

# Model-specific function for collecting CFG KV caches (set at runtime)
cfg_kv_collect_func: Any | None = None

# Quantization settings
# Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs)
quantization: str | None = None
Expand Down
44 changes: 34 additions & 10 deletions vllm_omni/diffusion/models/bagel/pipeline_bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,40 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
gen_context["past_key_values"] = injected_kv
seq_len = injected_kv.key_cache[0].shape[0]
gen_context["kv_lens"] = [seq_len]
gen_context["ropes"] = [seq_len]

# Disable CFG: single KV cache cannot support 3-branch CFG
logger.warning("CFG is disabled when using injected KV Cache")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)
if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
gen_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
else:
gen_context["ropes"] = [seq_len]

cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
if cfg_text_kv is not None:
logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0]
cfg_text_context["past_key_values"] = cfg_text_kv
cfg_text_context["kv_lens"] = [cfg_text_seq_len]
cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
if cfg_text_metadata and "ropes" in cfg_text_metadata:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]

cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None)
if cfg_img_metadata and "ropes" in cfg_img_metadata:
cfg_img_context["ropes"] = cfg_img_metadata["ropes"]
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
else:
logger.warning("CFG is disabled: only single KV cache available")
gen_params = BagelGenParams(
num_timesteps=gen_params.num_timesteps,
timestep_shift=gen_params.timestep_shift,
cfg_text_scale=1.0,
cfg_img_scale=1.0,
)

else:
image_input = (
Expand Down
8 changes: 6 additions & 2 deletions vllm_omni/diffusion/worker/diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if len(req.prompts) == 0:
raise ValueError("Cannot execute model with empty request list")

# The manager handles the check for need_recv_cache internally
self.kv_transfer_manager.receive_kv_cache(req, target_device=getattr(self.pipeline, "device", None))
cfg_kv_collect_func = getattr(self.od_config, "cfg_kv_collect_func", None)
self.kv_transfer_manager.receive_multi_kv_cache(
req,
cfg_kv_collect_func=cfg_kv_collect_func,
target_device=getattr(self.pipeline, "device", None),
)

if req.sampling_params.generator is None and req.sampling_params.seed is not None:
if req.sampling_params.generator_device is not None:
Expand Down
47 changes: 47 additions & 0 deletions vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,50 @@ def receive_kv_cache(self, req: Any, target_device: torch.device | None = None)
self.apply_kv_cache_to_request(req, data)
return True
return False

def receive_multi_kv_cache(
self,
req: Any,
cfg_kv_collect_func: Callable | None = None,
target_device: torch.device | None = None,
) -> bool:
"""Receive primary KV cache and optional CFG companion KV caches.

First receives the primary KV cache (existing logic). Then, if the
request carries cfg_kv_request_ids and a model-specific
cfg_kv_collect_func is provided, calls it to fetch and attach the
companion KV caches to sampling_params.

Args:
req: Request object with request_id and sampling_params.
cfg_kv_collect_func: Model-specific function for collecting
CFG KV caches. Signature:
(request_id, cfg_request_ids, kv_transfer_manager, target_device)
-> dict[str, Any]
target_device: Device to move tensors to.

Returns:
True if primary KV cache was received successfully.
"""
primary_ok = self.receive_kv_cache(req, target_device)

cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
if cfg_ids and cfg_kv_collect_func:
request_id = getattr(req, "request_id", None) or (
req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
)
try:
cfg_kvs = cfg_kv_collect_func(
request_id,
cfg_ids,
self,
target_device,
)
if cfg_kvs and hasattr(req, "sampling_params") and req.sampling_params is not None:
for key, value in cfg_kvs.items():
setattr(req.sampling_params, key, value)
logger.info("Applied CFG KV caches: %s", list(cfg_kvs.keys()))
except Exception:
logger.exception("Failed to collect CFG KV caches for %s", request_id)

return primary_ok
4 changes: 4 additions & 0 deletions vllm_omni/entrypoints/async_omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
# Capture stage info from kwargs before they might be filtered out
stage_id = kwargs.get("stage_id")
engine_input_source = kwargs.get("engine_input_source")
cfg_kv_collect_func = kwargs.pop("cfg_kv_collect_func", None)

# Build config
if od_config is None:
Expand Down Expand Up @@ -103,6 +104,9 @@ def __init__(
else:
raise

if cfg_kv_collect_func is not None:
od_config.cfg_kv_collect_func = cfg_kv_collect_func

# Initialize engine
self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config)

Expand Down
Loading