[Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training #21204
Rockdu wants to merge 10 commits into sgl-project:main
Conversation
…port Rebased onto latest main.
Replace the original monolithic flow_matching_with_logprob patch with a modular mixin-based architecture:
- SchedulerRLMixin: core rollout logic (prepare, SDE/CPS/ODE sampling, log-prob accumulation, resource lifecycle)
- SchedulerRLDebugMixin: optional debug tensor collection
- RolloutSessionData: per-batch state dataclass stored on the batch object
- All rollout state lives on the batch, keeping the scheduler stateless

Made-with: Cursor
- Pass the batch object through scheduler.step() to enable per-request rollout
- Add _maybe_prepare_rollout / _maybe_collect_rollout_log_probs lifecycle hooks in the denoising stage
- Wire the rollout flow through the decoding and denoising_dmd stages

Made-with: Cursor
- Add rollout_sde_type, rollout_noise_level, rollout_log_prob_no_const, and rollout_debug_mode to SamplingParams with validation
- Propagate the parameters through the OpenAI-compatible image/video endpoints
- Wire them through diffusion_generator and gpu_worker

Made-with: Cursor
- Add shard/gather latent helpers in the zimage DiT for sequence parallelism

Made-with: Cursor
- ODE mode: bit-exact alignment against the FlowGRPO reference implementation
- SDE/CPS mode: verify log-prob sign, shape, and noise injection behavior
- Validate the prepare/consume/release lifecycle and edge cases

Made-with: Cursor
…ults

- Run isort/black/ruff formatting on all changed files
- Remove unused TeaCacheParams imports from schedule_batch.py (F401)
- Rewrite the FlowGRPO alignment test: use the verbatim reference from sd3_sde_with_logprob.py and verify log_prob at atol=1e-6
- Match the FlowGRPO convention: SDE uses the full Gaussian log-prob (no_const=False), CPS uses no_const=True
- Remove explicit defaults from the rollout CLI args to fix test_get_cli_args_drops_unset_sampling_params

Made-with: Cursor
Rollout is an internal post-training feature; it should not be exposed through the standardized OpenAI image/video generation endpoints. Parameters remain accessible via SamplingParams CLI and direct generator API. Made-with: Cursor
…rollout_unit

The file tests ODE, SDE, and CPS modes; the old name was misleading.

Made-with: Cursor
Summary of Changes (Gemini Code Assist)

This pull request introduces a robust, modular log-probability computation engine for reinforcement-learning-based post-training of diffusion models. It computes per-step log-probabilities along the denoising trajectory, supports the SDE, CPS, and ODE sampling strategies, and is designed to be stateless and compatible with distributed training setups, so the new functionality integrates without impacting existing performance or correctness.
Code Review
This pull request introduces a comprehensive implementation for computing rollout log-probabilities in diffusion models, which is essential for RL-based post-training. The changes are well-structured, using mixins to extend scheduler functionality and keeping the new logic isolated. The state management is thoughtfully designed to be stateless at the scheduler level, which is great for concurrent use. The addition of unit tests that align with a reference implementation from FlowGRPO provides strong confidence in the correctness of the complex SDE/CPS formulas.
I have two suggestions for improvement. One is a high-severity issue regarding a potential division-by-zero error when rollout_noise_level is 0. The other is a medium-severity suggestion for a minor performance optimization by pre-calculating a constant. Overall, this is a solid contribution.
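The division-by-zero concern can be addressed by clamping the noise standard deviation before it appears as a divisor in the per-step Gaussian log-prob. A minimal sketch (the function name, `eps` parameter, and scalar signature are illustrative, not the PR's actual code):

```python
import math

def gaussian_log_prob(x: float, mean: float, std: float, eps: float = 1e-8) -> float:
    """Log-density of N(mean, std^2) at x, with std clamped away from zero.

    Clamping keeps the computation finite when rollout_noise_level is 0
    (the pure-ODE limit, where the step is deterministic).
    """
    std = max(std, eps)  # guard against division by zero
    return (
        -((x - mean) ** 2) / (2.0 * std**2)
        - math.log(std)
        - 0.5 * math.log(2.0 * math.pi)
    )

# With std=0 the clamp keeps the result finite instead of raising/returning inf:
print(math.isfinite(gaussian_log_prob(0.1, 0.1, 0.0)))  # True
```

In the real engine the same clamp would apply elementwise to the `noise_std_dev` tensor before the log-prob reduction.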
```python
    action="store_true",
    help="Whether to return the trajectory",
)
add_argument(
```
Could we use a dedicated parser (for RL) and add the arguments of that parser here?
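The reviewer's suggestion could take the shape of a dedicated argument group registered from one helper. A sketch, assuming stdlib `argparse`; the helper name is hypothetical, and defaults are deliberately omitted so unset sampling params stay unset (matching the CLI-defaults fix mentioned in the commits):

```python
import argparse

def add_rollout_args(parser: argparse.ArgumentParser) -> None:
    """Register the RL rollout flags as one named group."""
    group = parser.add_argument_group("rollout (RL post-training)")
    group.add_argument("--rollout", action="store_true",
                       help="Enable rollout log-prob collection")
    group.add_argument("--rollout-sde-type", choices=["sde", "cps", "ode"],
                       help="Denoising strategy used during rollout")
    group.add_argument("--rollout-noise-level", type=float,
                       help="Noise level injected by the SDE/CPS sampler")

parser = argparse.ArgumentParser()
add_rollout_args(parser)
args = parser.parse_args(["--rollout", "--rollout-sde-type", "cps"])
print(args.rollout, args.rollout_sde_type)  # True cps
```

Grouping the flags this way keeps `--help` output organized and gives a single place to extend the RL surface later.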
[Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training
1. Architecture & Design
Motivation
RL-based post-training of diffusion models (e.g., FlowGRPO) requires computing per-step log-probabilities along the denoising trajectory. This PR adds a modular rollout log-prob engine to SGLang-D that supports three denoising strategies and is compatible with Sequence Parallelism (SP).
Design Principles
- Stateless scheduler: all rollout state lives in a `RolloutSessionData` dataclass attached to the batch object (`batch._rollout_session_data`). The scheduler itself stores nothing, making it safe for concurrent use.
- Mixin architecture: `SchedulerRLMixin` (core) and `SchedulerRLDebugMixin` (debug tensor collection), mixed into `FlowMatchEulerDiscreteScheduler` alongside the existing `SchedulerMixin` and `ConfigMixin`.
- Opt-in: rollout is enabled via `SamplingParams.rollout = True`. When disabled, no additional code paths execute and the scheduler step is unchanged.

Module Structure
Data Flow
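The per-batch lifecycle can be sketched as follows. Only the hook names (`_maybe_prepare_rollout`, `_maybe_collect_rollout_log_probs`), the `batch._rollout_session_data` attachment point, and the statelessness guarantee come from the PR; the rest is a schematic stand-in for the real scheduler and batch classes:

```python
from dataclasses import dataclass, field

@dataclass
class RolloutSessionData:
    # Per-batch rollout state; lives on the batch, never on the scheduler.
    log_probs: list = field(default_factory=list)

class Batch:
    def __init__(self, rollout: bool):
        self.rollout = rollout
        self._rollout_session_data = None

class Scheduler:
    """Stateless with respect to rollout: every method reads/writes the batch."""

    def _maybe_prepare_rollout(self, batch):
        if batch.rollout and batch._rollout_session_data is None:
            batch._rollout_session_data = RolloutSessionData()

    def step(self, batch, step_log_prob):
        # The batch object is threaded through step() to enable per-request rollout.
        self._maybe_prepare_rollout(batch)
        if batch.rollout:
            batch._rollout_session_data.log_probs.append(step_log_prob)

    def _maybe_collect_rollout_log_probs(self, batch):
        # Consume-and-release: the session data is detached from the batch.
        if batch.rollout and batch._rollout_session_data is not None:
            data, batch._rollout_session_data = batch._rollout_session_data, None
            return data.log_probs
        return None

sched = Scheduler()
batch = Batch(rollout=True)
for lp in [-1.2, -0.8]:
    sched.step(batch, lp)
print(sched._maybe_collect_rollout_log_probs(batch))  # [-1.2, -0.8]
```

Because the scheduler holds no per-request state, the same instance can serve interleaved rollout and non-rollout batches concurrently.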
2. Features, API & Reliability
Supported Rollout Strategies
`rollout_sde_type` selects the strategy: `"sde"`, `"cps"`, or `"ode"`.

API Parameters (`SamplingParams`)

| Parameter | Type | Default |
| --- | --- | --- |
| `rollout` | `bool` | `False` |
| `rollout_sde_type` | `str` | `"sde"` (one of `"sde"`, `"cps"`, `"ode"`) |
| `rollout_noise_level` | `float` | `0.7` |
| `rollout_log_prob_no_const` | `bool` | `False` |
| `rollout_debug_mode` | `bool` | `False` |

Output

When `rollout=True`, the batch's `rollout_trajectory_data` (type `RolloutTrajectoryData`) is populated:

- `rollout_log_probs`: `Tensor [B, T]` -- per-step log-probabilities, reduced as `global_sum / global_count` across SP ranks.
- `rollout_debug_tensors` (when `rollout_debug_mode=True`): a `RolloutDebugTensors` containing per-step `[B, T, ...]` tensors for `variance_noise`, `prev_sample_mean`, `noise_std_dev`, and `model_output`.

These are available through the OpenAI-compatible image/video API via the existing extra-fields mechanism.
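The `global_sum / global_count` reduction can be illustrated without a distributed setup. In the sketch below each SP rank holds only the log-prob sum and element count of its latent shard; in the PR this would presumably be an all-reduce over the sequence-parallel group, and the function name here is hypothetical:

```python
def reduce_log_probs_across_sp(per_rank_sums, per_rank_counts):
    """Combine per-shard log-prob sums into a global per-step mean.

    Summing sums and counts separately before dividing gives the same
    result as computing the mean over the full (unsharded) latent, even
    when shards have different sizes.
    """
    global_sum = sum(per_rank_sums)
    global_count = sum(per_rank_counts)
    return global_sum / global_count

# Two SP ranks, each holding half of the latent elements for one step:
print(reduce_log_probs_across_sp([-12.0, -8.0], [10, 10]))  # -1.0
```

Reducing sums and counts (rather than per-rank means) is what keeps the result identical to the single-GPU value regardless of how the latent is sharded.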
Reliability Testing
Three levels of testing validate correctness:
Unit tests (`test_scheduler_rollout_ode_unit.py`):

- ODE mode: `noise_std_dev` is zero and shapes are consistent.
- SDE/CPS mode: checks `prev_sample`, `prev_sample_mean`, and `noise_std_dev`.

ODE-vs-non-rollout end-to-end (Z-Image-Turbo, 1024x1024, seed=42): ODE rollout produces bit-identical output images to the standard non-rollout denoising path across all parallelism configs (TP1/SP2, TP2/SP1, TP1/SP1+CFGP).
SDE/CPS parallel consistency (Z-Image-Turbo, 1024x1024, seed=42): Intermediate tensors (variance noise, noise std, prev sample mean, model output) are compared across parallelism configs against a single-GPU reference, with all steps and first-step metrics reported.
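A consistency check of this kind reduces to comparing each collected debug tensor against the single-GPU reference. A minimal pure-Python cosine-similarity helper (an illustrative sketch, not the PR's test code, which operates on real tensors):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two flat vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

reference = [0.5, -1.0, 2.0]   # single-GPU tensor, flattened
parallel = [0.5, -1.0, 2.0]    # bit-exact parallel config
print(cosine_similarity(reference, parallel))  # ≈ 1.0
```

Bit-exact tensors (like `variance_noise` here) give similarity 1.0; tensors subject to floating-point reduction-order differences (like `prev_sample_mean`) stay near but below 1.0, which matches the thresholds reported below.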
3. Detailed Test Results
3.1 ODE vs Non-Rollout (Bit-Exact Alignment)
Model: `Tongyi-MAI/Z-Image-Turbo`, seed=42, 1024x1024, 2 GPUs.

All configs produce pixel-identical images, confirming the ODE path introduces zero numerical divergence from the standard denoising pipeline.
3.2 SDE/CPS Parallel Consistency
Model: `Tongyi-MAI/Z-Image-Turbo`, seed=42, noise_level=0.5, 9 steps, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP parallel).
SDE mode -- key metrics (all steps):
CPS mode -- key metrics (all steps):
Key observations across both modes:
- `variance_noise` and `noise_std_dev` are bit-exact across all configs -- the SP-aware noise generation correctly reproduces the single-GPU noise stream.
- `prev_sample_mean` and `model_output` show small accumulated differences in TP2/SP1 and CFGP-parallel configs, caused by non-deterministic floating-point reduction order in the transformer (not the rollout engine). First-step differences are negligible (cosine > 0.9999).

4. Summary
- `shard_latents_for_sp`-based noise generation guarantees identical random streams regardless of parallelism.
- `prev_sample_mean` divergences under TP2/CFGP-parallel are expected floating-point non-determinism from the DiT forward pass, not the rollout engine. These accumulate over steps but remain within cosine similarity > 0.996.
- All rollout state lives in `RolloutSessionData` -- the scheduler is fully stateless and safe for multi-request serving.
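The generate-then-shard trick behind the bit-exact noise stream can be sketched as follows. `shard_latents_for_sp` is the PR's helper name; the noise generator, the flat-list representation, and the contiguous-slice sharding are schematic assumptions standing in for the real tensor code:

```python
import random

def generate_full_noise(seed: int, total_len: int) -> list:
    """Every rank draws the *full* noise tensor from the same seed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(total_len)]

def shard_latents_for_sp(latents: list, sp_rank: int, sp_size: int) -> list:
    """Each SP rank keeps its contiguous slice of the full tensor."""
    shard = len(latents) // sp_size
    return latents[sp_rank * shard : (sp_rank + 1) * shard]

# Because each rank generates the full stream before slicing, the
# concatenation of shards reproduces the single-GPU noise stream exactly,
# independent of the SP degree.
full = generate_full_noise(seed=42, total_len=8)
shards = [shard_latents_for_sp(generate_full_noise(42, 8), r, 2) for r in range(2)]
print(shards[0] + shards[1] == full)  # True
```

Generating the full stream on every rank trades a little redundant RNG work for determinism: the noise seen by the model is identical whether the latent is sharded across 1, 2, or more ranks.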