[diffusion] feat: add rollout log_prob with flow-matching SDE/CPS support #18806

Open
MikukuOvO wants to merge 2 commits into sgl-project:main from MikukuOvO:feat/rollout-logprob-support

Conversation

@MikukuOvO
Contributor

Motivation

This PR adds rollout log_prob support for diffusion flow-matching pipelines.
Previously, rollout paths did not expose consistent log_prob signals for flow-matching variants (especially SDE/CPS), which limited downstream training/evaluation workflows that depend on likelihood-based objectives.

Modifications

  • Add rollout log_prob computation in the diffusion rollout path.
  • Add flow-matching rollout support for SDE mode.
  • Add flow-matching rollout support for CPS mode.
  • Expose log_prob in rollout outputs with consistent shape/semantics across modes.
  • Keep backward compatibility for existing rollout callers when log_prob is not requested.
  • Add/extend related tests and validations for shape/dtype, mode coverage (SDE/CPS), and regression behavior.
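To make the core idea concrete, here is a minimal, framework-free sketch of an SDE step that also returns its log-probability (this is illustrative only, not the PR's actual implementation; the function name, noise schedule, and sigma parameterization are assumptions):

```python
import math
import random

def sde_step_with_logprob_sketch(sample, velocity, sigma, sigma_next, noise_level, rng):
    """Illustrative flow-matching SDE step that also returns the step log-prob.

    The next sample is drawn from a Gaussian centered on the deterministic
    (ODE) update, so its log-probability under that Gaussian is available in
    closed form. All names and the noise scale here are illustrative.
    """
    dt = sigma_next - sigma
    # Deterministic flow-matching update: x_next = x + (sigma_next - sigma) * v
    mean = [x + dt * v for x, v in zip(sample, velocity)]
    # SDE variant injects Gaussian noise; clamp std to avoid log(0)
    std = max(noise_level * math.sqrt(abs(dt)), 1e-8)
    next_sample = [m + std * rng.gauss(0.0, 1.0) for m in mean]
    # Sum of independent Gaussian log-densities over all elements
    log_prob = sum(
        -((x - m) ** 2) / (2.0 * std * std) - math.log(std) - 0.5 * math.log(2.0 * math.pi)
        for x, m in zip(next_sample, mean)
    )
    return next_sample, log_prob
```

Because the transition is an explicit Gaussian, running twice with the same seeded generator reproduces both the latent and the log-prob, which is the determinism property the accuracy tests below rely on.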

Accuracy Tests

All tests were run with fixed random seeds.
Under the same seed and configuration, rollout outputs are deterministic: latent and log_prob are consistent across repeated runs.

Flux

  • Sampling modes tested: no rollout, CPS, SDE.
  • Verified deterministic consistency of both latent and log_prob under fixed seed.
  • Figure: sampling_comparison_figure_v2

Qwen

  • Sampling modes tested: no rollout, CPS, SDE.
  • Verified deterministic consistency of both latent and log_prob under fixed seed.
  • Figure: sampling_comparison_figure

Z-Image

  • Sampling modes tested: no rollout, CPS, SDE.
  • Verified deterministic consistency of both latent and log_prob under fixed seed.
  • Figure: sampling_comparison_figure_v3_compressed

Benchmarking and Profiling

  • No dedicated speed benchmark/profiling numbers are included in this PR yet.
  • Functional focus of this PR is rollout log_prob correctness and mode coverage.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the diffusion SGLang Diffusion label Feb 13, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @MikukuOvO, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the diffusion flow-matching pipeline by integrating the capability to compute and expose log_prob during the rollout process. This addition addresses a critical limitation where consistent log_prob signals were not available for flow-matching variants, particularly SDE and CPS modes. By providing these likelihood-based metrics, the change unlocks new possibilities for training and evaluation workflows that rely on such objectives, ensuring greater utility and flexibility for diffusion models.

Highlights

  • Rollout Log-Probability Support: Implemented log_prob computation during the diffusion rollout process, enabling likelihood-based objectives for downstream tasks.
  • Flow-Matching SDE/CPS Modes: Added specific support for Stochastic Differential Equation (SDE) and Conditional Probability Score (CPS) modes within the flow-matching rollout for log_prob calculation.
  • API and Parameter Exposure: Introduced new rollout and rollout_sde_type parameters in SamplingParams and exposed them through the OpenAI-compatible image and video generation APIs.
  • Trajectory Log-Probability Output: Ensured log_prob is consistently exposed in rollout outputs as trajectory_log_probs with appropriate shape and semantics across different modes.
  • Backward Compatibility: Maintained backward compatibility for existing rollout callers that do not explicitly request log_prob.


Changelog
  • python/sglang/multimodal_gen/configs/sample/sampling_params.py
    • Added rollout and rollout_sde_type fields to SamplingParams with default values.
    • Extended add_cli_args to include command-line arguments for --rollout and --rollout-sde-type.
  • python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py
    • Included trajectory_log_probs in the dictionary of output results for each prompt.
  • python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py
    • Added rollout and rollout_sde_type as optional parameters to _build_sampling_params_from_request.
    • Passed rollout and rollout_sde_type to the SamplingParams construction.
    • Included rollout and rollout_sde_type in the generations endpoint request.
  • python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py
    • Added optional rollout and rollout_sde_type fields to ImageGenerationsRequest.
    • Added optional rollout and rollout_sde_type fields to VideoGenerationsRequest.
  • python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py
    • Conditionally added rollout and rollout_sde_type to sampling_kwargs in _build_sampling_params_from_request.
    • Added rollout and rollout_sde_type as form parameters to the create_video endpoint.
    • Passed rollout and rollout_sde_type to the VideoGenerationsRequest.
  • python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
    • Included trajectory_log_probs when constructing the OutputBatch from the result.
  • python/sglang/multimodal_gen/runtime/pipelines/patches/__init__.py
    • Added an __init__.py file to the new patches directory.
  • python/sglang/multimodal_gen/runtime/pipelines/patches/flow_matching_with_logprob.py
    • Added a new file containing _as_timestep_tensor and sde_step_with_logprob functions for computing log-probabilities in SDE and CPS modes.
  • python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py
    • Added trajectory_log_probs as an optional torch.Tensor to the Req class.
    • Added trajectory_log_probs as an optional torch.Tensor to the OutputBatch class.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py
    • Passed trajectory_log_probs to the OutputBatch constructor.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py
    • Passed trajectory_log_probs to the OutputBatch constructor.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
    • Imported sde_step_with_logprob from the new patches module.
    • Patched sde_step_with_logprob onto the scheduler instance if not already present.
    • Initialized an empty list trajectory_log_probs for collecting log probabilities.
    • Added logic to determine rollout_enabled and rollout_sde_type, including warnings for invalid types.
    • Conditionally used self.scheduler.sde_step_with_logprob to compute latents and step_log_prob when rollout is enabled, appending step_log_prob to the trajectory.
    • Updated _post_denoising_loop to accept and process trajectory_log_probs, stacking them into a tensor if available.
  • python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py
    • Initialized an empty trajectory_log_probs list when calling _post_denoising_loop.
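The stacking behavior described for _post_denoising_loop can be sketched as follows (a list-based stand-in for torch.stack; the function name and shapes are assumptions, not the repository's code):

```python
def stack_trajectory_log_probs(per_step_log_probs):
    """Collapse a list of per-step [batch] log-prob rows into one
    [num_steps, batch] structure, or None when rollout was disabled.

    A list-based stand-in for torch.stack; names are illustrative.
    """
    if not per_step_log_probs:
        return None  # backward compatible: field stays None when not requested
    width = len(per_step_log_probs[0])
    # All steps must agree on batch size before stacking
    assert all(len(row) == width for row in per_step_log_probs), "ragged steps"
    return [list(row) for row in per_step_log_probs]
```

Returning None for the empty case is what keeps existing callers that never request log_prob unaffected.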
Activity
  • The author, MikukuOvO, has implemented the core functionality for log_prob computation in diffusion flow-matching pipelines.
  • The author has integrated the new rollout and rollout_sde_type parameters across various configuration, API, and runtime files.
  • The author has added a new module flow_matching_with_logprob.py to handle the specific SDE/CPS log-probability calculations.
  • The author has performed accuracy tests across Flux, Qwen, and Z-Image models, verifying deterministic consistency of latent and log_prob under fixed seeds.
  • The author has noted that benchmarking and profiling are not yet included, focusing on functional correctness and mode coverage for this PR.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces log_prob computation for rollouts in flow-matching diffusion pipelines, adding support for SDE and CPS modes. The changes are well-structured, correctly adding new sampling parameters, API endpoints, and plumbing the log_prob results through the pipeline stages. My review includes one suggestion to improve performance in the newly added sde_step_with_logprob function by optimizing a loop that could cause GPU-CPU synchronization overhead.
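The sync concern raised here (also echoed later as "t.item() in log_prob path causes GPU sync overhead") is usually addressed by moving the whole timestep sequence to the host once, before the loop. A hedged sketch with plain floats standing in for tensors (with real torch tensors this would be a single `timesteps.tolist()` transfer):

```python
def precompute_timesteps(timesteps):
    """Convert the whole timestep sequence to host-side scalars once,
    before the denoising loop, instead of calling t.item() per step.

    With real tensors this is one device-to-host transfer; here plain
    floats stand in for the tensor elements.
    """
    return [float(t) for t in timesteps]

# Usage sketch: one transfer up front, then cheap Python floats per step
host_timesteps = precompute_timesteps([999, 749, 499, 249])
```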

@MikukuOvO MikukuOvO force-pushed the feat/rollout-logprob-support branch from 23595dd to 020befb on February 13, 2026 at 20:20
@zhaochenyang20
Collaborator

/rerun-failed-ci

1 similar comment
@zhaochenyang20
Collaborator

/rerun-failed-ci

@MikukuOvO
Contributor Author

/rerun-failed-ci

1 similar comment
@zhaochenyang20
Collaborator

/rerun-failed-ci

@zhaochenyang20
Collaborator

Rebase and fix lint, please.

@zhaochenyang20
Collaborator

Under the same seed and configuration, rollout outputs are deterministic: latent and log_prob are consistent across repeated runs.

This is nice. Do you think we can leverage latent and log_prob as metrics for CI?


@zhaochenyang20 zhaochenyang20 left a comment


  1. Fix the lint.
  2. Could you add unit tests for SDE and CPS? These APIs are important.

@zhaochenyang20
Collaborator

A unit test could look like:

#19164 (comment)
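A hedged sketch of what such a unit test could check (determinism under a fixed seed, latent shape, scalar log-prob, and that SDE vs. CPS actually differ). The `rollout_step` function here is a stand-in for the real scheduler step, and its CPS scaling is an illustrative assumption:

```python
import math
import random
import unittest

def rollout_step(sample, sigma, sigma_next, sde_type, noise_level, rng):
    """Stand-in for the real SDE/CPS step; returns (next_sample, log_prob)."""
    dt = sigma_next - sigma
    # Illustrative: CPS uses a different noise scale than SDE
    scale = noise_level if sde_type == "sde" else 0.5 * noise_level
    std = max(scale * math.sqrt(abs(dt)), 1e-8)
    mean = list(sample)
    nxt = [m + std * rng.gauss(0.0, 1.0) for m in mean]
    lp = sum(-((x - m) ** 2) / (2 * std * std) - math.log(std)
             - 0.5 * math.log(2 * math.pi) for x, m in zip(nxt, mean))
    return nxt, lp

class TestRolloutLogProb(unittest.TestCase):
    def test_shapes_and_determinism(self):
        for mode in ("sde", "cps"):
            a = rollout_step([0.0, 1.0], 1.0, 0.9, mode, 0.7, random.Random(0))
            b = rollout_step([0.0, 1.0], 1.0, 0.9, mode, 0.7, random.Random(0))
            self.assertEqual(a, b)                # fixed seed => identical run
            self.assertEqual(len(a[0]), 2)        # latent shape preserved
            self.assertTrue(math.isfinite(a[1]))  # scalar, finite log-prob

    def test_modes_differ(self):
        s = rollout_step([0.0], 1.0, 0.9, "sde", 0.7, random.Random(0))
        c = rollout_step([0.0], 1.0, 0.9, "cps", 0.7, random.Random(0))
        self.assertNotEqual(s, c)
```

A real test for this PR would call the patched scheduler instead of the stand-in, but the assertion structure carries over.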

haonan3 added a commit to celve/sglang that referenced this pull request Feb 25, 2026
…#19153 (sleep/wake)

Cherry-picked from:
- PR sgl-project#18806 (MikukuOvO): flow-matching SDE/CPS log_prob
- PR sgl-project#19153 (Godmook): release/resume memory occupation

Known issues:
- t.item() in log_prob path causes GPU sync overhead
- release_memory_occupation tags only supports "weights"
@MikukuOvO
Contributor Author

Under the same seed and configuration, rollout outputs are deterministic: latent and log_prob are consistent across repeated runs.

This is nice. Do you think we can leverage latent and log_prob as metrics for CI?

Great suggestion. Yes, I think we can leverage both latent and log_prob as CI metrics for rollout regression checks.
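One way such a CI metric could look: compare the rollout's log_probs against golden values recorded under the fixed seed, with allclose-style tolerances. This is a sketch under assumed names and tolerances, mirroring `torch.allclose` semantics on plain floats:

```python
def rollout_regression_ok(log_probs, golden, rtol=1e-4, atol=1e-6):
    """CI-style regression check: compare rollout log_probs against golden
    values recorded under a fixed seed. Tolerances are illustrative.
    """
    if len(log_probs) != len(golden):
        return False  # shape regression counts as a failure too
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(log_probs, golden))
```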

@MikukuOvO
Contributor Author

Thanks for the detailed review. I have addressed all the requested changes above (sync/device handling, comment/doc cleanup, and removal of dynamic getattr/hasattr usage in the touched paths).

I am currently working on adding unit tests for both SDE and CPS rollout paths, and will push the test updates next.

haonan3 added a commit to celve/sglang that referenced this pull request Feb 26, 2026
…#19153 (sleep/wake)

Cherry-picked from:
- PR sgl-project#18806 (MikukuOvO): flow-matching SDE/CPS log_prob
- PR sgl-project#19153 (Godmook): release/resume memory occupation

Known issues:
- t.item() in log_prob path causes GPU sync overhead
- release_memory_occupation tags only supports "weights"
@zhaochenyang20
Collaborator

/rerun-failed-ci

MikukuOvO added a commit to MikukuOvO/sglang that referenced this pull request Mar 1, 2026
MikukuOvO added a commit to MikukuOvO/sglang that referenced this pull request Mar 2, 2026
@MikukuOvO
Contributor Author

Thanks for the reviews! I've gone through all your comments and pushed the fixes.

@zhaochenyang20 zhaochenyang20 force-pushed the feat/rollout-logprob-support branch from 4abb27b to f1d30d1 on March 3, 2026 at 02:07
@zhaochenyang20
Collaborator

zhaochenyang20 commented Mar 3, 2026

These are my verification commands:

  1. Install the changes:

     cd python
     uv pip install -e ".[diffusion]"
  2. With Python (note that currently only the Python API returns log_probs):

     from sglang import DiffGenerator

     gen = DiffGenerator.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")

     # Mode 1: No Rollout (baseline)
     result_baseline = gen.generate(sampling_params_kwargs={
         "prompt": "A curious raccoon in a forest",
         "rollout": False, "save_output": True,
     })

     # Mode 2: SDE Rollout
     result_sde = gen.generate(sampling_params_kwargs={
         "prompt": "A curious raccoon in a forest",
         "rollout": True, "rollout_sde_type": "sde",
         "rollout_noise_level": 0.7, "return_trajectory_latents": True, "save_output": True,
     })

     # Mode 3: CPS Rollout
     result_cps = gen.generate(sampling_params_kwargs={
         "prompt": "A curious raccoon in a forest",
         "rollout": True, "rollout_sde_type": "cps",
         "rollout_noise_level": 0.7, "return_trajectory_latents": True, "save_output": True,
     })

     # Determinism check: two runs with the same seed must match
     results = []
     for i in range(2):
         r = gen.generate(sampling_params_kwargs={
             "prompt": "A curious raccoon in a forest",
             "rollout": True, "rollout_sde_type": "sde",
             "rollout_noise_level": 0.7, "return_trajectory_latents": True, "seed": 42,
         })
         results.append(r)

     import torch
     print(f"Log_probs match: {torch.allclose(results[0].trajectory_log_probs, results[1].trajectory_log_probs)}")
     print(f"Latents match: {torch.allclose(results[0].trajectory_latents, results[1].trajectory_latents)}")
  3. With the server:

     sglang serve --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --num-gpus 1

  # No Rollout
  curl http://localhost:30000/v1/images/generations \
    -H "Content-Type: application/json" \
    -d '{
      "prompt": "A curious raccoon in a forest",
      "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    }'

  # SDE Rollout
  curl http://localhost:30000/v1/images/generations \
    -H "Content-Type: application/json" \
    -d '{
      "prompt": "A curious raccoon in a forest",
      "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
      "rollout": true,
      "rollout_sde_type": "sde",
      "rollout_noise_level": 0.7
    }'

  # CPS Rollout
  curl http://localhost:30000/v1/images/generations \
    -H "Content-Type: application/json" \
    -d '{
      "prompt": "A curious raccoon in a forest",
      "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
      "rollout": true,
      "rollout_sde_type": "cps",
      "rollout_noise_level": 0.7
    }'

@zhaochenyang20
Collaborator

With Python. Note that only python API can get log_probs.

This is a critical issue. Every RL workload runs against a server, so the curl API should 100% have some way to get log_probs and latents.

@zhaochenyang20
Collaborator

zhaochenyang20 commented Mar 4, 2026

I don't think that adding latent and log_prob to the server response is a problem.

  1. LLMs have this requirement too. Of course this is required.
  2. The code change is limited. I believe no more than 30 lines of code.
  3. The latency effect is limited as well. If the user does not request these two fields, we do not send them back, so latency is unaffected.

Considering the latency of transferring large data via HTTPS: if you think that's slow, you are correct. But we still need to send them back. So:

  1. Add request fields, like get_log_probs: True, get_latent: True in your request. They default to False; when true, transfer the data back.
  2. Evaluate the time for transferring this large data through the HTTPS server. The basic idea is to copy what we do for LLMs: just check how the LLM path returns log_probs, and do the same thing.
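The opt-in pattern suggested above could be sketched like this (a hypothetical response builder, not the repository's code; field names, float32 packing, and base64 transport are all assumptions for illustration):

```python
import base64
import struct

def build_generation_response(image_b64, log_probs=None, latents=None,
                              get_log_probs=False, get_latent=False):
    """Sketch of opt-in extra response fields: tensors are only serialized
    (here: packed float32, base64) when the request asked for them, so the
    default path pays no extra latency. Field names are illustrative.
    """
    resp = {"data": [{"b64_json": image_b64}]}
    if get_log_probs and log_probs is not None:
        raw = struct.pack(f"{len(log_probs)}f", *log_probs)
        resp["trajectory_log_probs"] = base64.b64encode(raw).decode("ascii")
    if get_latent and latents is not None:
        raw = struct.pack(f"{len(latents)}f", *latents)
        resp["trajectory_latents"] = base64.b64encode(raw).decode("ascii")
    return resp

def decode_floats(b64):
    """Client-side inverse: base64 -> packed bytes -> list of floats."""
    raw = base64.b64decode(b64)
    return list(struct.unpack(f"{len(raw) // 4}f", raw))
```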

@MikukuOvO
Contributor Author

Quick manual verification (Qwen, Python API only)

cd python
uv pip install -e ".[diffusion]"

python - <<'PY'
import torch
from sglang.multimodal_gen import DiffGenerator

MODEL = "Qwen/Qwen-Image"
PROMPT = "A curious raccoon in a forest"

def one(gen, **kwargs):
    r = gen.generate(sampling_params_kwargs=kwargs)
    return r[0] if isinstance(r, list) else r

common = dict(
    prompt=PROMPT,
    save_output=False,              # quick test only
    rollout_noise_level=0.7,
    num_inference_steps=20,         # reduce runtime
)

with DiffGenerator.from_pretrained(model_path=MODEL, num_gpus=1) as gen:
    baseline = one(gen, **common, rollout=False)
    sde = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True)
    cps = one(gen, **common, rollout=True, rollout_sde_type="cps", return_trajectory_latents=True)

    assert baseline.trajectory_log_probs is None
    assert sde.trajectory_log_probs is not None and sde.trajectory_latents is not None
    assert cps.trajectory_log_probs is not None and cps.trajectory_latents is not None

    r0 = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True, seed=42)
    r1 = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True, seed=42)

    lp_ok = torch.allclose(r0.trajectory_log_probs, r1.trajectory_log_probs)
    lat_ok = torch.allclose(r0.trajectory_latents, r1.trajectory_latents)
    assert lp_ok and lat_ok

    print("SDE log_probs shape:", tuple(sde.trajectory_log_probs.shape))
    print("SDE latents shape:", tuple(sde.trajectory_latents.shape))
    print("CPS log_probs shape:", tuple(cps.trajectory_log_probs.shape))
    print("CPS latents shape:", tuple(cps.trajectory_latents.shape))
    print("Determinism log_probs:", lp_ok)
    print("Determinism latents:", lat_ok)
    print("Quick rollout verification passed.")
PY

@yhyang201
Collaborator

@mickqian Nvidia CI passed and PR is approved, ready for merge

Updated docstring to clarify sde_type options.
Comment on lines +1 to +2
# SPDX-License-Identifier: Apache-2.0
"""Flow-matching rollout step utilities for log-prob computation."""

If this is adapted from another open-source diffusion workflow, we should add an acknowledgment here.

Comment on lines +1110 to +1128
if rollout_enabled:
    latents, step_log_prob = sde_step_with_logprob(
        self.scheduler,
        model_output=noise_pred,
        sample=latents,
        step_index=rollout_step_indices[i],
        generator=batch.generator,
        sde_type=rollout_sde_type,
        noise_level=rollout_noise_level,
    )
    trajectory_log_probs.append(step_log_prob)
else:
    latents = self.scheduler.step(
        model_output=noise_pred,
        timestep=t_device,
        sample=latents,
        **extra_step_kwargs,
        return_dict=False,
    )[0]

It's a little unclear to me how sde_step_with_logprob and self.scheduler.step differ in their parameters. The strangest part is that sde_step_with_logprob takes self.scheduler as a parameter, while self.scheduler.step is a method on the scheduler object. Could we use the same design pattern for both, e.g.:

  1. change sde_step_with_logprob to self.scheduler.sde_step_with_logprob, or
  2. keep a single entrypoint self.scheduler.step, and pass step_index=rollout_step_indices[i], generator=batch.generator, sde_type=rollout_sde_type, noise_level=rollout_noise_level as kwargs?

Indeed, I am not so sure about the process of SDE and CPS. We should ask BBuf, mick, and Yuhao for help on the design.
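Option 1 above (exposing it as self.scheduler.sde_step_with_logprob) can be achieved by binding the free function onto the scheduler instance. A minimal sketch with a dummy scheduler standing in for the real one (class and function bodies are illustrative):

```python
import types

class DummyScheduler:
    """Stand-in for a diffusers-style scheduler; only step() is sketched."""
    def step(self, model_output, timestep, sample):
        return (sample,)

def sde_step_with_logprob(self, model_output, sample, step_index):
    # Once bound via MethodType, `self` is the scheduler instance,
    # matching the self.scheduler.sde_step_with_logprob call style.
    return sample, 0.0  # illustrative body only

sched = DummyScheduler()
if not hasattr(sched, "sde_step_with_logprob"):
    # Bind the function as an instance method, as the denoising stage does
    sched.sde_step_with_logprob = types.MethodType(sde_step_with_logprob, sched)
```

With the binding in place, both entry points share the method-call pattern: `sched.step(...)` and `sched.sde_step_with_logprob(...)`.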


        return_dict=False,
    )[0]
if rollout_enabled:
    latents, step_log_prob = sde_step_with_logprob(
@alphabetc1 commented Mar 9, 2026

This rollout path bypasses self.scheduler.step(...) and directly computes the next sample from sigmas.

That seems not equivalent for multi-step schedulers like FlowUniPCMultistepScheduler, because their step() also updates internal state such as last_sample, model_outputs, timestep_list, and lower_order_nums.

Is this by design?
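One conservative way to address this concern, if bypassing step() is intentional, is to refuse log-prob rollout for schedulers that carry multi-step solver state. A hedged sketch (the guard function and dummy classes are illustrative; the attribute names follow the comment above):

```python
def supports_sde_logprob_rollout(scheduler):
    """Conservative guard: refuse log-prob rollout for schedulers that keep
    multi-step solver history, since bypassing step() would leave attributes
    like last_sample / model_outputs / lower_order_nums desynchronized.
    This is a sketch, not the repository's actual check.
    """
    stateful_attrs = ("last_sample", "model_outputs", "timestep_list", "lower_order_nums")
    return not any(hasattr(scheduler, a) for a in stateful_attrs)

class EulerLike:
    """Stateless first-order solver: safe to bypass step()."""
    pass

class UniPCLike:
    """Multi-step solver with internal history, like FlowUniPCMultistepScheduler."""
    def __init__(self):
        self.model_outputs = []
        self.lower_order_nums = 0
```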

save_output: bool = True
return_frames: bool = False
rollout: bool = False
rollout_sde_type: str = "sde"

Can we validate rollout_sde_type in _validate so the request fails early?
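Such an early check could look like this (a hypothetical standalone sketch; the constant and function names are illustrative, not the actual _validate implementation):

```python
VALID_ROLLOUT_SDE_TYPES = ("sde", "cps")

def validate_rollout_params(rollout: bool, rollout_sde_type: str) -> None:
    """Fail fast at request-validation time rather than emitting a warning
    deep inside the denoising loop. Names here are illustrative.
    """
    if rollout and rollout_sde_type not in VALID_ROLLOUT_SDE_TYPES:
        raise ValueError(
            f"rollout_sde_type must be one of {VALID_ROLLOUT_SDE_TYPES}, "
            f"got {rollout_sde_type!r}"
        )
```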

@Rockdu Rockdu force-pushed the feat/rollout-logprob-support branch from 19d822d to 91e3f0d on March 23, 2026 at 07:54
@Rockdu Rockdu force-pushed the feat/rollout-logprob-support branch 2 times, most recently from b0f98c8 to cd46eea on March 23, 2026 at 08:27
@Rockdu

Rockdu commented Mar 23, 2026

Hi, since the original PR is relatively simple and lacks the necessary support for parallel inference, I revamped the rollout part based on @MikukuOvO's version. See #21204
[Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training

@Rockdu Rockdu force-pushed the feat/rollout-logprob-support branch from 3f0fa57 to 19d822d on March 23, 2026 at 09:41

Labels

diffusion SGLang Diffusion run-ci
