[Feature]: Support cfg kv-cache transfer in multi-stage #1422

hsliuustc0106 merged 32 commits into vllm-project:main from …
Conversation
Signed-off-by: princepride <wangzhipeng628@gmail.com>
@natureofnature @hsliuustc0106 @ZJY0516 Can you take a look after the holiday ends? 😊
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6cde323408
hsliuustc
left a comment
PR Review Summary
Overview
This PR adds CFG (Classifier-Free Guidance) KV-cache transfer for multi-stage inference in vllm-omni. It enables high-quality 3-branch CFG in Bagel's AR → Diffusion pipeline without degrading image quality.
Stats: 12 files changed, +487 / -18 lines
Critical Issues (Must Fix)
1. Race condition in _forward_parent_with_cfg (P1) 🚨
Location: vllm_omni/entrypoints/omni.py:~1046
Problem: When multiple CFG-enabled requests are in flight, _forward_parent_with_cfg recomputes next_inputs from shared mutable state (self.stage_list[0].engine_outputs) which can be overwritten by a different request's Stage-0 output. This causes the diffusion stage to receive token IDs from the wrong parent request.
Recommended Fix: Use the saved parent_result["engine_outputs"] instead of recomputing from shared state.
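A minimal sketch of the recommended fix, using simplified `Stage`/`Orchestrator` stand-ins (the real orchestrator in `omni.py` is much larger, and the method names here are illustrative): snapshot each parent's Stage-0 output per request and read from the snapshot, not from the shared slot that later requests overwrite.

```python
# Sketch only: snapshot-per-request vs. shared mutable state.
class Stage:
    def __init__(self):
        self.engine_outputs = None  # shared slot, overwritten by every request


class Orchestrator:
    def __init__(self):
        self.stage_list = [Stage()]
        self._pending_parent_results = {}

    def on_stage0_complete(self, request_id, outputs):
        # The shared slot is clobbered by whichever request finishes last...
        self.stage_list[0].engine_outputs = outputs
        # ...so keep a per-request snapshot for the CFG forwarding step.
        self._pending_parent_results[request_id] = {"engine_outputs": list(outputs)}

    def forward_parent_with_cfg(self, request_id):
        # Read the saved snapshot instead of recomputing from shared state.
        parent_result = self._pending_parent_results.pop(request_id)
        return parent_result["engine_outputs"]
```

With two in-flight requests, the snapshot path returns each parent's own tokens even after the shared slot has been overwritten.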
2. Missing error handling for companion request failures 🚨
Location: vllm_omni/entrypoints/omni.py ~806-825
Problem: If a CFG companion request fails at Stage-0, the parent request will wait indefinitely → deadlock.
Recommended Fix: Add timeout and error propagation.
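One way to sketch the timeout and error-propagation fix (names and structure are hypothetical; the 120 s default mirrors the `VLLM_CFG_PENDING_TIMEOUT_S` value discussed later in the thread): record a deadline when a parent starts waiting, fail it immediately on companion error, and expire it if the deadline passes.

```python
import time

class PendingParents:
    """Track parents waiting on CFG companions; never wait forever."""

    def __init__(self, timeout_s: float = 120.0):
        self.timeout_s = timeout_s
        self._pending = {}   # parent_id -> (saved_result, deadline)
        self.failed = set()

    def add(self, parent_id, result, now=None):
        now = time.monotonic() if now is None else now
        self._pending[parent_id] = (result, now + self.timeout_s)

    def on_companion_error(self, parent_id):
        # Propagate a companion failure right away instead of deadlocking.
        if self._pending.pop(parent_id, None) is not None:
            self.failed.add(parent_id)

    def expire(self, now=None):
        # Called periodically from the scheduling loop.
        now = time.monotonic() if now is None else now
        expired = [pid for pid, (_, ddl) in self._pending.items() if now >= ddl]
        for pid in expired:
            del self._pending[pid]
            self.failed.add(pid)
        return expired
```

Marking the parent as failed (rather than silently dropping it) lets the scheduling loop account for the request and surface an error to the caller.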
3. Memory leak potential with _pending_parent_results 🚨
Location: vllm_omni/entrypoints/omni.py ~935-940
Problem: Failed companion requests are never cleaned up.
Recommended Fix: Implement cleanup on error/timeout paths.
Important Issues (Should Fix)
4. Missing validation in collect_cfg_kv_caches
vllm_omni/model_executor/stage_input_processors/bagel.py lines 107-139
5. Batch size limitation without documentation
vllm_omni/model_executor/stage_configs/bagel.yaml line 7
6. No handling for img2img case
vllm_omni/model_executor/stage_input_processors/bagel.py line 76
Positive Aspects ✅
- Clean separation of concerns
- Extensible hook-based design
- Backward compatible
- CI checks passing
Overall Assessment
The core design is sound, but the race condition is a must-fix before merge. Once the critical issues are addressed, this will be a valuable addition.
Action items: Fix race condition, add error handling, implement cleanup for memory leaks.
lishunyang12
left a comment
Nice approach to CFG KV-cache transfer. I have a few concerns around error handling and concurrency.
…add companion timeout
@vllm-omni-reviewer
🤖 VLLM-Omni PR Review

Code Review: Support CFG KV-Cache Transfer in Multi-Stage

1. Overview

This PR implements CFG (Classifier-Free Guidance) KV-cache transfer for multi-stage inference in VLLM-Omni. The key insight is that BAGEL's 3-branch CFG requires multiple prompts through the AR stage (gen, cfg_text, cfg_img), and their KV caches need to be transferred to the diffusion stage.

Key Changes:
Overall Assessment: Positive - The implementation is well-structured and addresses a real limitation. The before/after images demonstrate clear quality improvement.

2. Code Quality

Strengths
Issues
```python
sp_next = copy.copy(sampling_params_list[next_stage_id])
```

This is a shallow copy. If …
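The shallow-copy concern in a nutshell: `copy.copy` duplicates only the top-level object, so any mutable attribute (for example a per-request KV-cache list attached to the sampling params) stays shared between the original and the copy. The `SamplingParams` stand-in below is illustrative, not the real vLLM class.

```python
import copy

class SamplingParams:
    def __init__(self):
        self.temperature = 1.0
        self.cfg_text_past_key_values = []  # mutable attribute

sp = SamplingParams()

sp_next = copy.copy(sp)                  # shallow: the list is shared
sp_next.cfg_text_past_key_values.append("kv")
# the mutation through the copy leaks back into the original

sp_deep = copy.deepcopy(sp)              # deep: the list is duplicated
sp_deep.cfg_text_past_key_values.append("other")
# the original is unaffected by mutations through the deep copy
```

If the copied params are mutated per-stage, `copy.deepcopy` (or copying only the fields that are actually changed) avoids cross-request leakage.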
```python
cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
# ... later ...
cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
```

Consider extracting …
```python
return "<|im_start|><|im_end|>"
```

Consider making this a constant or a configurable value.

3. Architecture & Design

Strengths
Concerns
```python
def _load_func_from_config(stage_config: Any, attr_name: str):
    func_path = getattr(stage_config, attr_name, None)
    if not func_path:
        return None
    module_path, func_name = func_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, func_name)
```

This dynamically imports and calls functions based on config. If config files can be user-controlled, this is a potential security vector. Consider:
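One possible mitigation along these lines is a prefix allowlist checked before the import. This is a sketch; the function name, signature, and default prefix are assumptions, not the PR's actual code.

```python
import importlib

def load_func_from_path(func_path: str, allowed_prefixes=("vllm_omni.",)):
    """Resolve 'pkg.mod.fn' only when it lives under a trusted prefix."""
    if not func_path.startswith(tuple(allowed_prefixes)):
        raise ValueError(f"untrusted function path: {func_path!r}")
    module_path, func_name = func_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, func_name)
```

Rejecting paths before `import_module` runs prevents config-driven execution of arbitrary module import side effects.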
```python
class CFGCompanionManager:
    """Manages CFG companion request lifecycle."""

    def __init__(self, prompt_expand_func, timeout_s: float = 120.0):
        self.companion_map: dict[str, dict[str, str]] = {}
        self.companion_ids: set[str] = set()
        self.companion_done: dict[str, set[str]] = {}
        self.pending_parents: dict[str, Any] = {}
        self.failed_parents: set[str] = set()

    def expand_prompts(self, request_id_to_prompt, sampling_params): ...
    def on_companion_complete(self, req_id): ...
    def check_pending_timeouts(self): ...
```

4. Security & Safety

Resource Management
Input Validation
```python
except Exception:
    logger.exception("Failed to collect CFG KV caches for %s", request_id)
```

Potential Issues

5. Testing & Documentation

Test Coverage
Documentation
Suggested Test Cases

```python
# Test prompt expansion
def test_expand_cfg_prompts_text2img():
    prompt = {"prompt": "A cat", "modalities": ["image"]}
    result = expand_cfg_prompts(prompt, mock_sampling_params)
    assert len(result) == 1
    assert result[0].role == "cfg_text"

# Test timeout handling
def test_cfg_companion_timeout():
    # Verify parent request is properly cleaned up after timeout
    pass

# Test companion failure propagation
def test_cfg_companion_failure_propagates():
    # Verify parent fails when companion fails
    pass
```

6. Specific Suggestions
hsliuustc0106
left a comment
Overall the design is solid — the hook-based approach keeps model-specific CFG logic out of the orchestrator core, and the companion request paradigm is a reasonable way to handle multi-branch CFG across disaggregated stages.
Main concerns:
- Verify request accounting (completed_requests) is correct in all paths (normal, timeout, companion failure); bugs there would cause hangs.
- O(n) parent lookup for companion requests could be improved with a reverse index.
- The source_outputs_override temporary mutation of shared state is not thread-safe.
- Unused utility functions (is_cfg_companion_request, get_parent_request_id) should be removed or deferred.
See inline comments for details.
The parameter is passed as a plain int (value-copied), so any mutation inside the method would never propagate back to the caller's loop. The method never uses it either. Remove it to avoid misleading future maintainers into thinking they can increment it here.
Replace two O(parents × roles) linear scans in the error-handling and companion-completion paths with a pre-built cfg_companion_to_parent dict populated during the expansion phase.
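The reverse index described in this commit can be sketched as follows; the names mirror the commit message, but the data structures are simplified stand-ins for the orchestrator's real state.

```python
companion_map = {}            # parent_id -> {role: companion_id}
cfg_companion_to_parent = {}  # companion_id -> parent_id (reverse index)

def register_companions(parent_id, role_to_companion):
    """Populate both maps during the expansion phase."""
    companion_map[parent_id] = dict(role_to_companion)
    for companion_id in role_to_companion.values():
        cfg_companion_to_parent[companion_id] = parent_id

def parent_of(companion_id):
    # O(1) lookup instead of scanning parents × roles.
    return cfg_companion_to_parent.get(companion_id)

register_companions("p1", {"cfg_text": "p1__cfg_text", "cfg_img": "p1__cfg_img"})
```

Building the index once at expansion time moves the cost out of the hot error-handling and completion paths.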
The temporary mutation of stage_list[].engine_outputs is safe today because the orchestrator loop is single-threaded, but it would race under concurrency. Add a NOTE comment to flag this for future work.
Neither function is called anywhere in the codebase. Removing them keeps the diff focused; they can be re-added when actually needed.
The resolved function object is never used in the orchestrator process; only the raw string travels to the worker for re-resolution. Add a comment explaining this is intentional early validation of the import.
An empty string negative_prompt is not meaningful for Bagel CFG guidance. Change `if neg is not None` to `if neg` so that both None and "" fall through to the default token pair.
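The distinction this commit fixes, in isolation: `is not None` accepts the empty string, while a plain truthiness check routes both `None` and `""` to the default. The helper and default below are illustrative stand-ins; the real token pair lives in `bagel.py`.

```python
DEFAULT_NEG = "<|im_start|><|im_end|>"

def resolve_negative_prompt(neg):
    # `if neg is not None` would wrongly keep "" here;
    # truthiness sends both None and "" to the default pair.
    return neg if neg else DEFAULT_NEG
```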
Stage-0 max_batch_size=2 assumes single-prompt inference (1 user + 1 CFG companion). For multi-prompt batches it should scale accordingly.
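The scaling concern can be stated as a small formula: with one companion per prompt, Stage-0 needs one slot for the user request plus one per companion, so a batch of N prompts needs N × (1 + companions) slots. The helper name is hypothetical.

```python
def required_stage0_batch_size(num_prompts: int, companions_per_prompt: int = 1) -> int:
    """Slots needed at Stage-0: each prompt brings itself plus its CFG companions."""
    return num_prompts * (1 + companions_per_prompt)
```

The current yaml value of 2 matches `required_stage0_batch_size(1)`, i.e. the single-prompt case only.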
Force-pushed from c4ef389 to 6cca43d
@vllm-omni-reviewer
🤖 VLLM-Omni PR Review

Code Review: Support CFG KV-Cache Transfer in Multi-Stage

1. Overview

This PR implements Classifier-Free Guidance (CFG) support across disaggregated multi-stage inference pipelines through a "companion request" paradigm. The implementation:
The approach is well-designed, allowing CFG to work efficiently without redundant text encoder evaluations on diffusion workers.

Overall Assessment: Positive - The implementation is solid, well-documented, and follows the existing architecture patterns. A few suggestions for robustness and maintainability are noted below.

2. Code Quality

Strengths
Concerns

Magic Strings: Several string literals are used throughout that should be constants:

```python
# In omni.py and bagel.py
"cfg_text", "cfg_img", "__cfg_text"
```

Type Annotations: Some functions in …

Complex State Management: The orchestrator in …

3. Architecture & Design

Strengths
Suggestions

Thread Safety Acknowledgment: The comment in …

```python
# NOTE: This relies on the orchestrator being single-threaded.
# If concurrency is introduced, replace with a per-call context
# or a thread-local to avoid racing on shared mutable state.
```

Consider adding a docstring or assertion to document this assumption explicitly.

4. Security & Safety

Dynamic Function Loading

The `_load_func_from_config` helper:
```python
def _load_func_from_config(stage_config: Any, attr_name: str):
    """Dynamically import a function referenced by a dotted path in stage config."""
    func_path = getattr(stage_config, attr_name, None)
    if not func_path:
        return None
    module_path, func_name = func_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, func_name)
```

Suggestion: Consider validating that the function path starts with an allowed prefix (e.g., …).

Resource Management

The timeout mechanism for pending parents is good:
However, the default timeout of 120 seconds may be too long for some use cases. Consider documenting this environment variable more prominently.

5. Testing & Documentation

Documentation

Documentation is comprehensive and well-written. The architecture overview clearly explains the CFG companion flow.

Testing

The test changes update reference pixels to match the new CFG-enabled output, but there are no unit tests for:
Suggestion: Add unit tests for the new processor functions in …

6. Specific Suggestions
lishunyang12
left a comment
Final Review Pass
All my original concerns have been addressed:
- Deadlock risk -- Resolved. Companion failure propagation (error path + timeout expiry at VLLM_CFG_PENDING_TIMEOUT_S) covers all three failure scenarios.
- Linear scan -- Resolved. The cfg_companion_to_parent reverse index gives O(1) lookup.
- Race condition on shared engine_outputs -- Resolved. _forward_parent_with_cfg passes the saved parent_result["engine_outputs"] via source_outputs_override. The single-threaded assumption NOTE is appreciated.
- receive_kv_cache_for_request definition -- Confirmed existing in the codebase.
- img2img CFG skip -- Understood; img2img bypasses CFG expansion by design.
- Hard failure for connector -- Agreed; KV transfer is required for CFG, so RuntimeError is correct.
- Batch size comment -- Added and clear.
Also noted that feedback from @hsliuustc0106 was addressed: unused completed_requests param removed, empty-string negative prompt handled, unused utility functions removed, clarifying comments added.
Minor suggestions for follow-up (non-blocking):
- Unit tests for expand_cfg_prompts and collect_cfg_kv_caches would improve confidence in edge cases
- Role strings ("cfg_text", "cfg_img") could be extracted to constants alongside CFG_TEXT_SUFFIX
- cfg_kv_collect_func typing in diffusion/data.py could use Callable instead of Any
LGTM -- approving.
@hsliuustc0106 @tzhouam I encapsulated the CFG prompt expansion and companion tracking logic in the Omni orchestrator, PTAL
lishunyang12
left a comment
The refactoring into CfgCompanionTracker is a big improvement. One issue: if forward_parent_with_cfg hits an exception in process_engine_inputs, the parent was already popped from _pending_parents via pop_pending_parent. That means the request is orphaned -- it won't be caught by timeout, and completed_requests never increments, so the scheduling loop hangs.
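One way to close the gap described above (a sketch; `CfgCompanionTracker`, `pop_pending_parent`, and `process_engine_inputs` follow the comment's names, but the class here is a simplified stand-in): on exception, re-account for the already-popped parent instead of orphaning it, so the scheduling loop's `completed_requests` check still terminates.

```python
class CfgCompanionTracker:
    def __init__(self):
        self._pending_parents = {}
        self.failed_parents = set()
        self.completed_requests = 0

    def pop_pending_parent(self, parent_id):
        return self._pending_parents.pop(parent_id)

    def forward_parent_with_cfg(self, parent_id, process_engine_inputs):
        result = self.pop_pending_parent(parent_id)
        try:
            return process_engine_inputs(result)
        except Exception:
            # The parent is already out of _pending_parents, so the
            # timeout path can no longer catch it. Mark it failed and
            # count it as completed so the loop does not hang.
            self.failed_parents.add(parent_id)
            self.completed_requests += 1
            raise
```

An alternative is to pop only after `process_engine_inputs` succeeds; either way, the invariant is that every popped parent eventually increments the loop's accounting.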
@hsliuustc0106 @tzhouam I think @lishunyang12's suggestion is reasonable. Our original code did not properly handle data that failed to be sent to the next stage.
lishunyang12
left a comment
All three concerns addressed. The completed_requests accounting fix in the original code path is a good catch too. LGTM.
Review Summary

PR Type: Feature - CFG KV-cache transfer for multi-stage inference

Key Observations:
Suggestions:
Purpose
Related: #1419
Test Plan
Multi-stage inference tasks can now generate outputs with the same high quality as the DiT-only stage.
Multi-Stage Test:
Diffusion Test:
Test Result
Before:

(image: generated output before this PR)

After:

(image: generated output after this PR)