[diffusion] feat: add rollout log_prob with flow-matching SDE/CPS support #18806
MikukuOvO wants to merge 2 commits into sgl-project:main
Conversation
Code Review (Gemini Code Assist)
This pull request introduces log_prob computation for rollouts in flow-matching diffusion pipelines, adding support for SDE and CPS modes. The changes are well-structured, correctly adding new sampling parameters, API endpoints, and plumbing the log_prob results through the pipeline stages. My review includes one suggestion to improve performance in the newly added sde_step_with_logprob function by optimizing a loop that could cause GPU-CPU synchronization overhead.
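The sync concern the review raises can be illustrated with generic torch code (a sketch, not the PR's actual `sde_step_with_logprob`): calling `.item()` inside the denoising loop forces a GPU-to-CPU round trip per step, while a single advanced-indexing gather keeps everything on-device.

```python
import torch

def gather_sigmas_looped(sigmas: torch.Tensor, step_indices: torch.Tensor) -> list:
    # Anti-pattern: .item() forces a device->host sync on every iteration.
    return [sigmas[i.item()].item() for i in step_indices]

def gather_sigmas_vectorized(sigmas: torch.Tensor, step_indices: torch.Tensor) -> torch.Tensor:
    # One gather via advanced indexing; no per-step synchronization.
    return sigmas[step_indices]

sigmas = torch.linspace(1.0, 0.0, steps=11)
idx = torch.tensor([0, 2, 5, 10])
looped = torch.tensor(gather_sigmas_looped(sigmas, idx))
vectorized = gather_sigmas_vectorized(sigmas, idx)
print(torch.allclose(looped, vectorized))  # True: both paths agree
```

Both paths return the same values; only the synchronization behavior differs, which matters once `sigmas` lives on a GPU.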
/rerun-failed-ci

/rerun-failed-ci
Rebase and fix lint, please.
This is nice. Do you think we can leverage `latent` and `log_prob` as CI metrics for rollout regression checks?
zhaochenyang20 left a comment
- Fix the lint.
- Could you add unit tests for SDE and CPS? These APIs are important.
A unit test could be like:
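For illustration only, a self-contained sketch of the seeded-determinism pattern such a test could use. The stub `fake_rollout` stands in for the real pipeline; an actual test would call `DiffGenerator` with `seed=...` as in the verification scripts later in this thread.

```python
import torch

def fake_rollout(seed: int, num_steps: int = 4):
    # Stand-in for a seeded SDE rollout returning (latents, log_probs).
    # A real test would run the pipeline with sampling_params seed=seed.
    g = torch.Generator().manual_seed(seed)
    latents = torch.randn(num_steps, 8, generator=g)
    log_probs = torch.randn(num_steps, generator=g)
    return latents, log_probs

def test_rollout_is_deterministic_under_fixed_seed():
    lat0, lp0 = fake_rollout(seed=42)
    lat1, lp1 = fake_rollout(seed=42)
    assert torch.allclose(lat0, lat1), "trajectory latents must match across runs"
    assert torch.allclose(lp0, lp1), "trajectory log_probs must match across runs"

test_rollout_is_deterministic_under_fixed_seed()
print("determinism check passed")
```

The same pattern applies to both SDE and CPS by parameterizing the test over `rollout_sde_type`.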
…#19153 (sleep/wake)
Cherry-picked from:
- PR sgl-project#18806 (MikukuOvO): flow-matching SDE/CPS log_prob
- PR sgl-project#19153 (Godmook): release/resume memory occupation
Known issues:
- t.item() in log_prob path causes GPU sync overhead
- release_memory_occupation tags only supports "weights"
Great suggestion. Yes, I think we can leverage both latent and log_prob as CI metrics for rollout regression checks.
Thanks for the detailed review. I have addressed all the requested changes above (sync/device handling, comment/doc cleanup, and removal of dynamic getattr/hasattr usage in the touched paths). I am currently working on adding unit tests for both SDE and CPS rollout paths, and will push the test updates next.
/rerun-failed-ci |
Thanks for the reviews! I've gone through all your comments and pushed the fixes.
…port Rebased onto latest main.
These are my verification commands:

```bash
cd python
uv pip install -e ".[diffusion]"
```

```python
import torch
from sglang import DiffGenerator

gen = DiffGenerator.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")

# Mode 1: No Rollout (baseline)
result_baseline = gen.generate(sampling_params_kwargs={
    "prompt": "A curious raccoon in a forest",
    "rollout": False, "save_output": True,
})

# Mode 2: SDE Rollout
result_sde = gen.generate(sampling_params_kwargs={
    "prompt": "A curious raccoon in a forest",
    "rollout": True, "rollout_sde_type": "sde",
    "rollout_noise_level": 0.7, "return_trajectory_latents": True, "save_output": True,
})

# Mode 3: CPS Rollout
result_cps = gen.generate(sampling_params_kwargs={
    "prompt": "A curious raccoon in a forest",
    "rollout": True, "rollout_sde_type": "cps",
    "rollout_noise_level": 0.7, "return_trajectory_latents": True, "save_output": True,
})

# Determinism check: two runs with the same seed must agree.
results = []
for i in range(2):
    r = gen.generate(sampling_params_kwargs={
        "prompt": "A curious raccoon in a forest",
        "rollout": True, "rollout_sde_type": "sde",
        "rollout_noise_level": 0.7, "return_trajectory_latents": True, "seed": 42,
    })
    results.append(r)

print(f"Log_probs match: {torch.allclose(results[0].trajectory_log_probs, results[1].trajectory_log_probs)}")
print(f"Latents match: {torch.allclose(results[0].trajectory_latents, results[1].trajectory_latents)}")
```

```bash
sglang serve --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --num-gpus 1

# No Rollout
curl http://localhost:30000/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A curious raccoon in a forest",
    "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
  }'

# SDE Rollout
curl http://localhost:30000/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A curious raccoon in a forest",
    "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "rollout": true,
    "rollout_sde_type": "sde",
    "rollout_noise_level": 0.7
  }'

# CPS Rollout
curl http://localhost:30000/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A curious raccoon in a forest",
    "model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "rollout": true,
    "rollout_sde_type": "cps",
    "rollout_noise_level": 0.7
  }'
```
This is a critical issue. Every RL workload runs against a server, so the curl API absolutely needs some way to get `log_prob` back.
I do think that adding this is needed.
Considering the latency of transferring large data via HTTP: if you think that's slow, you are correct. But we still need to send them back.
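One common way to carry trajectory tensors inside a JSON response (illustrative only, not what this PR implements) is to serialize them to a base64 string:

```python
import base64
import io
import torch

def encode_tensor(t: torch.Tensor) -> str:
    # Serialize a tensor into a base64 string that fits in a JSON field.
    buf = io.BytesIO()
    torch.save(t, buf)
    return base64.b64encode(buf.getvalue()).decode("ascii")

def decode_tensor(s: str) -> torch.Tensor:
    # Client side: reverse the encoding.
    return torch.load(io.BytesIO(base64.b64decode(s)), weights_only=True)

log_probs = torch.randn(20)  # e.g. one log_prob per denoising step
payload = {"trajectory_log_probs": encode_tensor(log_probs)}
roundtrip = decode_tensor(payload["trajectory_log_probs"])
print(torch.equal(log_probs, roundtrip))  # True
```

The base64 overhead is about 33%, which is usually acceptable for per-step log_probs; full trajectory latents are much larger and may warrant a separate transport.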
…port (sgl-project#18806) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Quick manual verification (Qwen, Python API only):

```bash
cd python
uv pip install -e ".[diffusion]"
python - <<'PY'
import torch
from sglang.multimodal_gen import DiffGenerator

MODEL = "Qwen/Qwen-Image"
PROMPT = "A curious raccoon in a forest"

def one(gen, **kwargs):
    r = gen.generate(sampling_params_kwargs=kwargs)
    return r[0] if isinstance(r, list) else r

common = dict(
    prompt=PROMPT,
    save_output=False,  # quick test only
    rollout_noise_level=0.7,
    num_inference_steps=20,  # reduce runtime
)

with DiffGenerator.from_pretrained(model_path=MODEL, num_gpus=1) as gen:
    baseline = one(gen, **common, rollout=False)
    sde = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True)
    cps = one(gen, **common, rollout=True, rollout_sde_type="cps", return_trajectory_latents=True)

    assert baseline.trajectory_log_probs is None
    assert sde.trajectory_log_probs is not None and sde.trajectory_latents is not None
    assert cps.trajectory_log_probs is not None and cps.trajectory_latents is not None

    r0 = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True, seed=42)
    r1 = one(gen, **common, rollout=True, rollout_sde_type="sde", return_trajectory_latents=True, seed=42)
    lp_ok = torch.allclose(r0.trajectory_log_probs, r1.trajectory_log_probs)
    lat_ok = torch.allclose(r0.trajectory_latents, r1.trajectory_latents)
    assert lp_ok and lat_ok

    print("SDE log_probs shape:", tuple(sde.trajectory_log_probs.shape))
    print("SDE latents shape:", tuple(sde.trajectory_latents.shape))
    print("CPS log_probs shape:", tuple(cps.trajectory_log_probs.shape))
    print("CPS latents shape:", tuple(cps.trajectory_latents.shape))
    print("Determinism log_probs:", lp_ok)
    print("Determinism latents:", lat_ok)
    print("Quick rollout verification passed.")
PY
```
@mickqian Nvidia CI passed and PR is approved, ready for merge |
Updated docstring to clarify sde_type options.
```python
# SPDX-License-Identifier: Apache-2.0
"""Flow-matching rollout step utilities for log-prob computation."""
```
If this is adapted from other open-source diffusion workflows, we should add an acknowledgment here.
```python
if rollout_enabled:
    latents, step_log_prob = sde_step_with_logprob(
        self.scheduler,
        model_output=noise_pred,
        sample=latents,
        step_index=rollout_step_indices[i],
        generator=batch.generator,
        sde_type=rollout_sde_type,
        noise_level=rollout_noise_level,
    )
    trajectory_log_probs.append(step_log_prob)
else:
    latents = self.scheduler.step(
        model_output=noise_pred,
        timestep=t_device,
        sample=latents,
        **extra_step_kwargs,
        return_dict=False,
    )[0]
```
The parameter design of sde_step_with_logprob vs. self.scheduler.step is a little unclear to me. The strangest part is that sde_step_with_logprob takes self.scheduler as a parameter, while self.scheduler.step is a method on the scheduler object. Could we share the same design pattern, e.g.:

- change `sde_step_with_logprob` to `self.scheduler.sde_step_with_logprob`, or
- keep a single entrypoint, `self.scheduler.step`, but pass `step_index=rollout_step_indices[i]`, `generator=batch.generator`, `sde_type=rollout_sde_type`, `noise_level=rollout_noise_level` as kwargs?
Indeed, I am not so sure about the SDE and CPS process. We shall ask BBuf, mick, and Yuhao for help with the design.
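The first option could be sketched like this (illustrative stub classes and a generic Euler-Maruyama step with a Gaussian log-density, not the sglang implementation or the PR's exact formula):

```python
import torch

class FlowMatchScheduler:
    """Minimal stand-in for a flow-matching scheduler (illustrative only)."""

    def __init__(self, num_steps: int = 10):
        self.sigmas = torch.linspace(1.0, 0.0, num_steps + 1)

    def step(self, model_output, sample, step_index):
        # Deterministic ODE update: follow the predicted velocity.
        dt = self.sigmas[step_index + 1] - self.sigmas[step_index]
        return sample + dt * model_output

    def sde_step_with_logprob(self, model_output, sample, step_index,
                              noise_level=0.7, generator=None):
        # Stochastic variant exposed as a *method*, so callers use one
        # object instead of passing the scheduler as a free-function arg.
        dt = self.sigmas[step_index + 1] - self.sigmas[step_index]
        mean = sample + dt * model_output
        std = noise_level * dt.abs().sqrt()
        noise = torch.randn(sample.shape, generator=generator)
        next_sample = mean + std * noise
        # Gaussian log-density of the realized next sample, averaged.
        log_prob = torch.distributions.Normal(mean, std).log_prob(next_sample)
        return next_sample, log_prob.mean()

sched = FlowMatchScheduler()
x, v = torch.zeros(4), torch.ones(4)
x_det = sched.step(v, x, step_index=0)
x_sde, lp = sched.sde_step_with_logprob(
    v, x, step_index=0, generator=torch.Generator().manual_seed(0))
print(x_det, x_sde, lp)
```

Both entry points then share the scheduler's state and sigma schedule, which also addresses the consistency concern below about bypassing `step()`.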
```python
        return_dict=False,
    )[0]
if rollout_enabled:
    latents, step_log_prob = sde_step_with_logprob(
```
This rollout path bypasses self.scheduler.step(...) and directly computes the next sample from sigmas.
That does not seem equivalent for multi-step schedulers like FlowUniPCMultistepScheduler, because their step() also updates internal state such as last_sample, model_outputs, timestep_list, and lower_order_nums.
Is this by design?
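The hazard can be shown with a toy multistep scheduler (illustrative only): if a custom step function computes the next sample directly, the scheduler's internal history never advances, so any later call to `step()` sees stale state.

```python
import torch

class ToyMultistepScheduler:
    """Toy scheduler whose step() keeps a history of model outputs,
    loosely mimicking multistep solvers (illustrative only)."""

    def __init__(self):
        self.model_outputs = []  # internal state a multistep solver relies on

    def step(self, model_output, sample, dt):
        self.model_outputs.append(model_output)
        if len(self.model_outputs) >= 2:
            # Second-order update: extrapolate from the previous output.
            slope = 1.5 * model_output - 0.5 * self.model_outputs[-2]
        else:
            slope = model_output  # first-order warm-up step
        return sample + dt * slope

def custom_step_bypassing_state(scheduler, model_output, sample, dt):
    # Computes the next sample directly and never touches
    # scheduler.model_outputs -> the solver history stays stale.
    return sample + dt * model_output

sched = ToyMultistepScheduler()
x = torch.zeros(1)
x = custom_step_bypassing_state(sched, torch.ones(1), x, dt=0.1)
print(len(sched.model_outputs))  # 0: history was never updated
x = sched.step(torch.ones(1), x, dt=0.1)
print(len(sched.model_outputs))  # 1: this step thinks it is the *first* step
```

For single-step schedulers (plain flow-matching Euler) the bypass is harmless; for multistep solvers the interleaving silently changes the numerics.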
```python
save_output: bool = True
return_frames: bool = False
rollout: bool = False
rollout_sde_type: str = "sde"
```
Can we validate rollout_sde_type in _validate so the request fails early?
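One possible shape for that early validation (a sketch: the field names follow the diff above, but the `_validate` hook and `VALID_ROLLOUT_SDE_TYPES` constant are assumptions):

```python
from dataclasses import dataclass

VALID_ROLLOUT_SDE_TYPES = ("sde", "cps")

@dataclass
class SamplingParams:
    save_output: bool = True
    return_frames: bool = False
    rollout: bool = False
    rollout_sde_type: str = "sde"

    def _validate(self) -> None:
        # Fail early, before any GPU work is scheduled.
        if self.rollout and self.rollout_sde_type not in VALID_ROLLOUT_SDE_TYPES:
            raise ValueError(
                f"rollout_sde_type must be one of {VALID_ROLLOUT_SDE_TYPES}, "
                f"got {self.rollout_sde_type!r}"
            )

SamplingParams(rollout=True, rollout_sde_type="sde")._validate()  # ok
try:
    SamplingParams(rollout=True, rollout_sde_type="ode")._validate()
except ValueError as e:
    print(e)
```

Rejecting an unknown value at request-parsing time turns a confusing mid-denoising failure into an immediate 4xx response.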
Hi, since the original PR is relatively simple and lacks the necessary support for parallel inference, I revamped the rollout part based on @MikukuOvO's version. See #21204.
Motivation

This PR adds rollout log_prob support for diffusion flow-matching pipelines. Previously, rollout paths did not expose consistent log_prob signals for flow-matching variants (especially SDE/CPS), which limited downstream training/evaluation workflows that depend on likelihood-based objectives.

Modifications

- Add log_prob computation in the diffusion rollout path.
- Return log_prob in rollout outputs with consistent shape/semantics across modes.
- Leave behavior unchanged when log_prob is not requested.

Accuracy Tests
All tests were run with fixed random seeds. Under the same seed and configuration, rollout outputs are deterministic: latent and log_prob are consistent across repeated runs.

- Flux: verified no rollout, CPS, and SDE; latent and log_prob are deterministic under a fixed seed.
- Qwen: verified no rollout, CPS, and SDE; latent and log_prob are deterministic under a fixed seed.
- Z-Image: verified no rollout, CPS, and SDE; latent and log_prob are deterministic under a fixed seed.

Benchmarking and Profiling

- Verified log_prob correctness and mode coverage.

Checklist
Review Process
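As background on the likelihood-based objectives the motivation mentions (not part of this PR): per-step rollout log_probs are exactly what PPO/GRPO-style importance ratios consume during RL fine-tuning of diffusion models. A generic sketch:

```python
import torch

def ppo_style_ratio_loss(old_log_probs, new_log_probs, advantages, clip_eps=0.2):
    # Importance ratio between the current policy and the rollout policy,
    # computed per denoising step from the stored log_probs.
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.minimum(unclipped, clipped).mean()

steps = 20
old_lp = torch.randn(steps)                 # log_probs saved during rollout
new_lp = old_lp + 0.01 * torch.randn(steps) # log_probs under the updated model
adv = torch.randn(steps)                    # e.g. reward-derived advantages
loss = ppo_style_ratio_loss(old_lp, new_lp, adv)
print(loss.shape)  # torch.Size([]) -- a scalar loss
```

This is why deterministic, consistently shaped `trajectory_log_probs` matter: the old and new log_probs must align step-for-step for the ratio to be meaningful.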