[Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training#21204

Open
Rockdu wants to merge 10 commits into sgl-project:main from MikukuOvO:feat/rollout-logprob-support-revamp-20260323
Conversation


@Rockdu Rockdu commented Mar 23, 2026

1. Architecture & Design

Motivation

RL-based post-training of diffusion models (e.g., FlowGRPO) requires computing per-step log-probabilities along the denoising trajectory. This PR adds a modular rollout log-prob engine to SGLang-D that supports three denoising strategies and is compatible with Sequence Parallelism (SP).

Design Principles

  • Stateless scheduler: All intermediate rollout state lives in a RolloutSessionData dataclass attached to the batch object (batch._rollout_session_data). The scheduler itself stores nothing, making it safe for concurrent use.
  • Composable mixins: Rollout logic is encapsulated in SchedulerRLMixin (core) and SchedulerRLDebugMixin (debug tensor collection), mixed into FlowMatchEulerDiscreteScheduler alongside the existing SchedulerMixin and ConfigMixin.
  • Opt-in activation: Rollout is triggered by SamplingParams.rollout = True. When disabled, no additional code paths execute and the scheduler step is unchanged.
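
As a rough sketch of the batch-scoped state pattern described above (the names `RolloutSessionData`, `prepare_rollout`, `append_local_rollout_log_probs`, and `release_rollout_resources` come from this PR; the field names and signatures below are illustrative guesses, not the real implementation):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class RolloutSessionData:
    # Hypothetical fields: per-step local log-prob sums and element counts.
    log_prob_local_sums: List[float] = field(default_factory=list)
    local_elem_counts: List[int] = field(default_factory=list)

class Batch:
    """Stand-in for the real batch object."""
    _rollout_session_data: Optional[RolloutSessionData] = None

class SchedulerRLMixin:
    """Stateless: all rollout state lives on the batch, never on self."""

    def prepare_rollout(self, batch: Batch) -> None:
        batch._rollout_session_data = RolloutSessionData()

    def append_local_rollout_log_probs(self, batch: Batch,
                                       local_sum: float, count: int) -> None:
        sd = batch._rollout_session_data
        sd.log_prob_local_sums.append(local_sum)
        sd.local_elem_counts.append(count)

    def release_rollout_resources(self, batch: Batch) -> None:
        batch._rollout_session_data = None
```

Because the mixin holds no state of its own, one scheduler instance can safely serve interleaved batches.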

Module Structure

```
runtime/post_training/
  rl_dataclasses.py            # RolloutSessionData, RolloutDebugTensors, RolloutTrajectoryData
  scheduler_rl_mixin.py        # SchedulerRLMixin  (core log-prob engine)
  scheduler_rl_debug_mixin.py  # SchedulerRLDebugMixin  (debug tensor accumulation)
```

Data Flow

```
prepare_rollout(batch, pipeline_config)
  --> creates RolloutSessionData on batch

for each denoising step:
  scheduler.step(batch=batch, ...)
    --> flow_sde_sampling(batch, model_output, sample, sigma, sigma_next, generator)
        --> computes prev_sample, log_prob_local_sum, local_elem_count
    --> append_local_rollout_log_probs(batch, ...)

after denoising loop:
  collect_rollout_log_probs(batch)      --> batch.rollout_trajectory_data.rollout_log_probs
  collect_rollout_debug_tensors(batch)  --> batch.rollout_trajectory_data.rollout_debug_tensors  (optional)
  release_rollout_resources(batch)      --> batch._rollout_session_data = None
```

2. Features, API & Reliability

Supported Rollout Strategies

| Strategy | `rollout_sde_type` | Description |
|---|---|---|
| SDE | `"sde"` | Standard stochastic differential equation with Gaussian diffusion noise |
| CPS | `"cps"` | Coefficients-Preserving Sampling (https://arxiv.org/pdf/2509.05952) |
| ODE | `"ode"` | Deterministic ODE step (zero variance, for alignment verification) |
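
For intuition, here is a minimal sketch of an SDE step with a Gaussian log-prob, in the spirit of the FlowGRPO reference this PR aligns against. The real `flow_sde_sampling` also handles the CPS and ODE branches, flow-matching coefficient schedules, and SP reduction, so treat this as illustrative only:

```python
import math
import torch

def sde_step_with_logprob(mean, std, generator=None, no_const=False):
    """Sample prev_sample ~ N(mean, std**2) and return its summed log-prob.

    NOTE: std must be strictly positive here; the ODE path (std == 0) must
    bypass this branch, otherwise the division below is undefined.
    """
    noise = torch.randn(mean.shape, generator=generator, dtype=mean.dtype)
    prev_sample = mean + std * noise
    # Elementwise Gaussian log-density at the sampled point.
    log_prob = -((prev_sample - mean) ** 2) / (2 * std ** 2) - torch.log(std)
    if not no_const:  # rollout_log_prob_no_const=True drops this constant
        log_prob = log_prob - 0.5 * math.log(2 * math.pi)
    return prev_sample, log_prob.sum()
```

The `no_const` flag mirrors `rollout_log_prob_no_const`: constant terms cancel in ratio-based RL losses, so dropping them is common.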

API Parameters (SamplingParams)

| Parameter | Type | Default | Description |
|---|---|---|---|
| `rollout` | bool | `False` | Enable rollout log-prob computation |
| `rollout_sde_type` | str | `"sde"` | Step strategy: `"sde"`, `"cps"`, or `"ode"` |
| `rollout_noise_level` | float | `0.7` | Noise level for SDE/CPS |
| `rollout_log_prob_no_const` | bool | `False` | Omit constant terms in log-prob (common for RL loss) |
| `rollout_debug_mode` | bool | `False` | Return intermediate tensors (variance noise, mean, std, model output) |
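
A hedged sketch of the validation these parameters imply (the actual checks inside `SamplingParams` may differ; this simply mirrors the table above):

```python
VALID_ROLLOUT_SDE_TYPES = ("sde", "cps", "ode")

def validate_rollout_params(rollout: bool,
                            rollout_sde_type: str = "sde",
                            rollout_noise_level: float = 0.7) -> None:
    """Hypothetical stand-in for SamplingParams validation of rollout args."""
    if not rollout:
        return  # rollout disabled: no extra code paths execute
    if rollout_sde_type not in VALID_ROLLOUT_SDE_TYPES:
        raise ValueError(
            f"rollout_sde_type must be one of {VALID_ROLLOUT_SDE_TYPES}, "
            f"got {rollout_sde_type!r}")
    if rollout_noise_level < 0.0:
        raise ValueError("rollout_noise_level must be non-negative")
```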

Output

When rollout=True, the batch's rollout_trajectory_data (type RolloutTrajectoryData) is populated:

  • rollout_log_probs: Tensor [B, T] -- per-step log-probabilities, reduced as global_sum / global_count across SP ranks.
  • rollout_debug_tensors (when rollout_debug_mode=True): RolloutDebugTensors containing per-step [B, T, ...] tensors for variance_noise, prev_sample_mean, noise_std_dev, and model_output.

These are available through the OpenAI-compatible image/video API via the existing extra-fields mechanism.
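
The `global_sum / global_count` reduction can be illustrated without `torch.distributed`; `reduce_rollout_log_probs` below is a hypothetical stand-in for the all-reduce inside `collect_rollout_log_probs`:

```python
def reduce_rollout_log_probs(local_sums, local_counts):
    """local_sums[r][t]: SP rank r's log-prob sum at denoising step t.

    Returns the per-step global mean, global_sum / global_count, matching
    the reduction described for rollout_log_probs.
    """
    num_steps = len(local_sums[0])
    result = []
    for t in range(num_steps):
        g_sum = sum(rank[t] for rank in local_sums)
        g_count = sum(rank[t] for rank in local_counts)
        result.append(g_sum / g_count)
    return result
```

Because each rank contributes only its local sum and element count, the reduced value is identical to what a single GPU would compute over the full latent.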

Reliability Testing

Three levels of testing validate correctness:

  1. Unit tests (test_scheduler_rollout_ode_unit.py):

    • ODE step determinism: confirms no variance noise is sampled.
    • Debug tensor shapes: verifies noise_std_dev is zero and shapes are consistent.
    • SDE/CPS FlowGRPO alignment: each single step is compared against a standalone FlowGRPO reference implementation across 4 seeds. Max absolute difference <= 1e-6 for prev_sample, prev_sample_mean, and noise_std_dev.
  2. ODE-vs-non-rollout end-to-end (Z-Image-Turbo, 1024x1024, seed=42): ODE rollout produces bit-identical output images to the standard non-rollout denoising path across all parallelism configs (TP1/SP2, TP2/SP1, TP1/SP1+CFGP).

  3. SDE/CPS parallel consistency (Z-Image-Turbo, 1024x1024, seed=42): Intermediate tensors (variance noise, noise std, prev sample mean, model output) are compared across parallelism configs against a single-GPU reference, with all steps and first-step metrics reported.
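
The metrics reported in Section 3 (max absolute difference, MSE, cosine similarity) can be computed along these lines; the helper name is ours, not the test suite's:

```python
import torch

def alignment_metrics(test: torch.Tensor, ref: torch.Tensor):
    """Return (max abs diff, MSE, cosine similarity) between two tensors."""
    diff = (test - ref).abs()
    cos = torch.nn.functional.cosine_similarity(
        test.flatten().float(), ref.flatten().float(), dim=0)
    return diff.max().item(), (diff ** 2).mean().item(), cos.item()
```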


3. Detailed Test Results

3.1 ODE vs Non-Rollout (Bit-Exact Alignment)

Model: Tongyi-MAI/Z-Image-Turbo, seed=42, 1024x1024, 2 GPUs.

| Config | TP | SP | CFGP Parallel | Exact Match | Max Abs Diff | MSE | Cosine Sim |
|---|---|---|---|---|---|---|---|
| TP1 SP2 | 1 | 2 | No | True | 0.0 | 0.0 | 1.0 |
| TP2 SP1 | 2 | 1 | No | True | 0.0 | 0.0 | 1.0 |
| TP1 SP1 CFGP | 1 | 1 | Yes | True | 0.0 | 0.0 | 1.0 |

All configs produce pixel-identical images, confirming the ODE path introduces zero numerical divergence from the standard denoising pipeline.

3.2 SDE/CPS Parallel Consistency

Model: Tongyi-MAI/Z-Image-Turbo, seed=42, noise_level=0.5, 9 steps, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP parallel).

SDE mode -- key metrics (all steps):

| Config | Tensor | Max Abs Diff | Mean MSE | Min Cosine | Mean Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP1 SP2 | prev_sample_mean | 0 | 0 | 1.0 | 1.0 |
| TP1 SP2 | noise_std_dev | 0 | 0 | 1.0 | 1.0 |
| TP2 SP1 | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP2 SP1 | prev_sample_mean | 2.17 | 3.68e-3 | 0.9978 | 0.9990 |
| TP2 SP1 | noise_std_dev | 0 | 0 | 1.0 | 1.0 |
| TP1 SP1 CFGP | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP1 SP1 CFGP | prev_sample_mean | 1.45 | 5.97e-4 | 0.9996 | 0.9998 |
| TP1 SP1 CFGP | noise_std_dev | 0 | 0 | 1.0 | 1.0 |

CPS mode -- key metrics (all steps):

| Config | Tensor | Max Abs Diff | Mean MSE | Min Cosine | Mean Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP1 SP2 | prev_sample_mean | 0 | 0 | 1.0 | 1.0 |
| TP2 SP1 | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP2 SP1 | prev_sample_mean | 2.32 | 5.12e-3 | 0.9961 | 0.9980 |
| TP1 SP1 CFGP | variance_noise | 0 | 0 | 1.0 | 1.0 |
| TP1 SP1 CFGP | prev_sample_mean | 2.10 | 9.20e-4 | 0.9992 | 0.9997 |

Key observations across both modes:

  • variance_noise and noise_std_dev are bit-exact across all configs -- the SP-aware noise generation correctly reproduces the single-GPU noise stream.
  • prev_sample_mean and model_output show small accumulated differences in TP2/SP1 and CFGP-parallel configs, caused by non-deterministic floating-point reduction order in the transformer (not the rollout engine). First-step differences are negligible (cosine > 0.9999).
  • SP2 (sequence parallel on the latent axis) is bit-exact with the single-GPU reference on all tensors.
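
One way to obtain this bit-exactness, consistent with the observation above (the real `shard_latents_for_sp`-based implementation may differ), is to draw noise for the full latent from a seeded generator on every rank and then keep only the local shard:

```python
import torch

def sp_noise_shard(full_shape, seed, sp_rank, sp_size, seq_dim=1):
    """Illustrative SP-aware noise generation.

    Every rank draws the FULL latent's noise from the same seeded generator,
    then keeps its shard along the sequence dimension. Since all ranks consume
    an identical random stream, concatenating the shards reproduces the
    single-GPU noise tensor bit-for-bit.
    """
    gen = torch.Generator().manual_seed(seed)
    full = torch.randn(full_shape, generator=gen)
    return full.chunk(sp_size, dim=seq_dim)[sp_rank]
```

The trade-off is that each rank generates (and discards most of) the full noise tensor, which buys determinism at a small memory/compute cost.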

4. Summary

  • The rollout log-prob engine correctly implements SDE, CPS, and ODE strategies as validated by unit tests aligned with FlowGRPO reference formulas.
  • ODE mode produces bit-identical outputs to the standard non-rollout path across all parallelism configurations, confirming zero regression on existing functionality.
  • Noise generation is bit-exact under SP -- the shard_latents_for_sp-based noise generation guarantees identical random streams regardless of parallelism.
  • Small prev_sample_mean divergences under TP2/CFGP-parallel are expected floating-point non-determinism from the DiT forward pass, not the rollout engine. These accumulate over steps but remain within cosine similarity > 0.996.
  • All rollout state is batch-scoped via RolloutSessionData -- the scheduler is fully stateless and safe for multi-request serving.

MikukuOvO and others added 9 commits March 23, 2026 07:33
Replace the original monolithic flow_matching_with_logprob patch with a
modular mixin-based architecture:

- SchedulerRLMixin: core rollout logic (prepare, SDE/CPS/ODE sampling,
  log-prob accumulation, resource lifecycle)
- SchedulerRLDebugMixin: optional debug tensor collection
- RolloutSessionData: per-batch state dataclass stored on batch object
- All rollout state lives on the batch, keeping the scheduler stateless

Made-with: Cursor
- Pass batch object through scheduler.step() to enable per-request rollout
- Add _maybe_prepare_rollout / _maybe_collect_rollout_log_probs lifecycle
  hooks in the denoising stage
- Wire rollout flow through decoding and denoising_dmd stages

Made-with: Cursor
- Add rollout_sde_type, rollout_noise_level, rollout_log_prob_no_const,
  rollout_debug_mode to SamplingParams with validation
- Propagate parameters through OpenAI-compatible image/video endpoints
- Wire through diffusion_generator and gpu_worker

Made-with: Cursor
- Add shard/gather latent helpers in zimage DiT for sequence parallelism

Made-with: Cursor
- ODE mode: bit-exact alignment against FlowGRPO reference implementation
- SDE/CPS mode: verify log-prob sign, shape, noise injection behavior
- Validate prepare/consume/release lifecycle and edge cases

Made-with: Cursor
…ults

- Run isort/black/ruff formatting on all changed files
- Remove unused TeaCacheParams imports from schedule_batch.py (F401)
- Rewrite FlowGRPO alignment test: use verbatim reference from
  sd3_sde_with_logprob.py, verify log_prob at atol=1e-6
- Match FlowGRPO convention: SDE uses full Gaussian log-prob
  (no_const=False), CPS uses no_const=True
- Remove explicit defaults from rollout CLI args to fix
  test_get_cli_args_drops_unset_sampling_params

Made-with: Cursor
Rollout is an internal post-training feature; it should not be exposed
through the standardized OpenAI image/video generation endpoints.
Parameters remain accessible via SamplingParams CLI and direct generator API.

Made-with: Cursor
…rollout_unit

The file tests ODE, SDE, and CPS modes — the old name was misleading.

Made-with: Cursor
@gemini-code-assist commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a robust and modular log-probability computation engine designed for reinforcement learning-based post-training of diffusion models. It enables the calculation of per-step log-probabilities during the denoising process, supporting various sampling strategies like SDE, CPS, and ODE. The design prioritizes statelessness and compatibility with distributed training setups, ensuring that the new functionality integrates seamlessly without impacting existing performance or correctness.

Highlights

  • Modular Rollout Log-Prob Engine: Introduced a new modular engine to compute per-step log-probabilities along the denoising trajectory, essential for RL-based post-training of diffusion models.
  • Multiple Denoising Strategies: Added support for three distinct denoising strategies: Stochastic Differential Equation (SDE), Coefficients-Preserving Sampling (CPS), and deterministic Ordinary Differential Equation (ODE) steps.
  • Stateless Scheduler Design: Implemented a stateless scheduler where all intermediate rollout state is managed via a RolloutSessionData dataclass attached to the batch object, ensuring concurrency safety.
  • Sequence Parallelism Compatibility: Ensured the new log-prob engine is fully compatible with Sequence Parallelism (SP), with noise generation being bit-exact across different parallelism configurations.
  • Comprehensive Reliability Testing: Validated the implementation through extensive unit tests, end-to-end ODE-vs-non-rollout comparisons (achieving bit-identical outputs), and SDE/CPS parallel consistency checks.



@gemini-code-assist bot left a comment

Code Review

This pull request introduces a comprehensive implementation for computing rollout log-probabilities in diffusion models, which is essential for RL-based post-training. The changes are well-structured, using mixins to extend scheduler functionality and keeping the new logic isolated. The state management is thoughtfully designed to be stateless at the scheduler level, which is great for concurrent use. The addition of unit tests that align with a reference implementation from FlowGRPO provides strong confidence in the correctness of the complex SDE/CPS formulas.

I have two suggestions for improvement. One is a high-severity issue regarding a potential division-by-zero error when rollout_noise_level is 0. The other is a medium-severity suggestion for a minor performance optimization by pre-calculating a constant. Overall, this is a solid contribution.

```python
action="store_true",
help="Whether to return the trajectory",
)
add_argument(
```
Collaborator
Could we use a dedicated parser (for RL) and add its arguments here?


Labels

diffusion SGLang Diffusion
