Skip to content

Commit f1d30d1

Browse files
[diffusion] feat: add rollout log_prob with flow-matching SDE/CPS support
Rebased onto latest main.
1 parent e6e02ec commit f1d30d1

File tree

14 files changed

+201
-9
lines changed

14 files changed

+201
-9
lines changed

python/sglang/multimodal_gen/configs/sample/sampling_params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ class SamplingParams:
159159
# Misc
160160
save_output: bool = True
161161
return_frames: bool = False
162+
rollout: bool = False
163+
rollout_sde_type: str = "sde"
164+
rollout_noise_level: float = 0.7
162165
return_trajectory_latents: bool = False # returns all latents for each timestep
163166
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
164167
# if True, disallow user params to override subclass-defined protected fields
@@ -306,6 +309,9 @@ def _finite_non_negative_float(
306309
_finite_non_negative_float(
307310
"guidance_rescale", self.guidance_rescale, allow_none=False
308311
)
312+
_finite_non_negative_float(
313+
"rollout_noise_level", self.rollout_noise_level, allow_none=False
314+
)
309315

310316
if self.cfg_normalization is None:
311317
self.cfg_normalization = 0.0
@@ -808,6 +814,25 @@ def add_cli_args(parser: Any) -> Any:
808814
default=SamplingParams.return_trajectory_latents,
809815
help="Whether to return the trajectory",
810816
)
817+
parser.add_argument(
818+
"--rollout",
819+
action="store_true",
820+
default=SamplingParams.rollout,
821+
help="Enable rollout mode and return per-step log_prob trajectory",
822+
)
823+
parser.add_argument(
824+
"--rollout-sde-type",
825+
type=str,
826+
choices=["sde", "cps"],
827+
default=SamplingParams.rollout_sde_type,
828+
help="Rollout step objective type used in log-prob computation.",
829+
)
830+
parser.add_argument(
831+
"--rollout-noise-level",
832+
type=float,
833+
default=SamplingParams.rollout_noise_level,
834+
help="Noise level used by rollout SDE/CPS step objective.",
835+
)
811836
parser.add_argument(
812837
"--return-trajectory-decoded",
813838
action="store_true",

python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def generate(
216216
),
217217
trajectory_latents=output_batch.trajectory_latents,
218218
trajectory_timesteps=output_batch.trajectory_timesteps,
219+
trajectory_log_probs=output_batch.trajectory_log_probs,
219220
trajectory_decoded=output_batch.trajectory_decoded,
220221
)
221222

python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ async def generations(
129129
true_cfg_scale=request.true_cfg_scale,
130130
negative_prompt=request.negative_prompt,
131131
enable_teacache=request.enable_teacache,
132+
rollout=request.rollout,
133+
rollout_sde_type=request.rollout_sde_type,
134+
rollout_noise_level=request.rollout_noise_level,
132135
output_compression=request.output_compression,
133136
output_quality=request.output_quality,
134137
)

python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class ImageGenerationsRequest(BaseModel):
4646
output_quality: Optional[str] = "default"
4747
output_compression: Optional[int] = None
4848
enable_teacache: Optional[bool] = False
49+
rollout: Optional[bool] = False
50+
rollout_sde_type: Optional[str] = "sde"
51+
rollout_noise_level: Optional[float] = 0.7
4952
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
5053

5154

@@ -98,6 +101,9 @@ class VideoGenerationsRequest(BaseModel):
98101
output_quality: Optional[str] = "default"
99102
output_compression: Optional[int] = None
100103
output_path: Optional[str] = None
104+
rollout: Optional[bool] = False
105+
rollout_sde_type: Optional[str] = "sde"
106+
rollout_noise_level: Optional[float] = 0.7
101107
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
102108

103109

python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque
7575
frame_interpolation_exp=request.frame_interpolation_exp,
7676
frame_interpolation_scale=request.frame_interpolation_scale,
7777
frame_interpolation_model_path=request.frame_interpolation_model_path,
78+
rollout=request.rollout,
79+
rollout_sde_type=request.rollout_sde_type,
80+
rollout_noise_level=request.rollout_noise_level,
7881
output_path=request.output_path,
7982
output_compression=request.output_compression,
8083
output_quality=request.output_quality,
@@ -181,6 +184,9 @@ async def create_video(
181184
frame_interpolation_exp: Optional[int] = Form(1),
182185
frame_interpolation_scale: Optional[float] = Form(1.0),
183186
frame_interpolation_model_path: Optional[str] = Form(None),
187+
rollout: Optional[bool] = Form(False),
188+
rollout_sde_type: Optional[str] = Form("sde"),
189+
rollout_noise_level: Optional[float] = Form(0.7),
184190
output_quality: Optional[str] = Form("default"),
185191
output_compression: Optional[int] = Form(None),
186192
extra_body: Optional[str] = Form(None),
@@ -256,6 +262,9 @@ async def create_video(
256262
frame_interpolation_exp=frame_interpolation_exp,
257263
frame_interpolation_scale=frame_interpolation_scale,
258264
frame_interpolation_model_path=frame_interpolation_model_path,
265+
rollout=rollout,
266+
rollout_sde_type=rollout_sde_type,
267+
rollout_noise_level=rollout_noise_level,
259268
output_compression=output_compression,
260269
output_quality=output_quality,
261270
**(

python/sglang/multimodal_gen/runtime/entrypoints/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class GenerationResult:
108108
metrics: dict = field(default_factory=dict)
109109
trajectory_latents: Any = None
110110
trajectory_timesteps: Any = None
111+
trajectory_log_probs: Any = None
111112
trajectory_decoded: Any = None
112113
prompt_index: int = 0
113114
output_file_path: str | None = None

python/sglang/multimodal_gen/runtime/managers/gpu_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
233233
metrics=result.metrics,
234234
trajectory_timesteps=getattr(result, "trajectory_timesteps", None),
235235
trajectory_latents=getattr(result, "trajectory_latents", None),
236+
trajectory_log_probs=getattr(result, "trajectory_log_probs", None),
236237
noise_pred=getattr(result, "noise_pred", None),
237238
trajectory_decoded=getattr(result, "trajectory_decoded", None),
238239
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Flow-matching rollout step utilities for log-prob computation."""
3+
4+
import math
5+
from typing import Any, Optional, Union
6+
7+
import torch
8+
from diffusers.utils.torch_utils import randn_tensor
9+
10+
11+
def _gaussian_log_prob(
    value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
) -> torch.Tensor:
    """Return the element-wise log-density of ``value`` under N(mean, std**2)."""
    return (
        -((value - mean) ** 2) / (2 * (std**2))
        - torch.log(std)
        # log(sqrt(2*pi)) as a plain float: no per-call tensor allocation.
        - 0.5 * math.log(2 * math.pi)
    )


def sde_step_with_logprob(
    self: Any,
    model_output: torch.FloatTensor,
    sample: torch.FloatTensor,
    step_index: int,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
    sde_type: str = "sde",
):
    """Run one rollout step and compute per-sample log_prob.

    Both supported transitions are Gaussian: a mean and a standard deviation
    are derived from ``sample``/``model_output`` and the sigmas at
    ``step_index`` and ``step_index + 1``; the log-density of ``prev_sample``
    under that Gaussian is then averaged over every non-batch dimension.

    Args:
        self: Object exposing an indexable ``sigmas`` tensor (presumably a
            flow-matching scheduler -- confirm against the caller).
        model_output: Model prediction for the current step; same shape as
            ``sample``.
        sample: Current latents; dimension 0 is treated as the batch.
        step_index: Index into ``self.sigmas`` for the current step. The next
            index is clamped to the last valid entry.
        noise_level: Scales the transition standard deviation.
        prev_sample: If given, the log-prob is evaluated at this value instead
            of drawing a new sample (no random noise is generated).
        generator: Forwarded to ``randn_tensor`` when a sample is drawn.
        sde_type: ``"sde"`` or ``"cps"``; selects how the transition mean and
            standard deviation are computed.

    Returns:
        Tuple ``(prev_sample, log_prob)``: ``prev_sample`` cast back to the
        original dtype of ``sample`` and ``log_prob`` with one value per
        batch element.

    Raises:
        ValueError: If ``sde_type`` is neither ``"sde"`` nor ``"cps"``.
    """
    # Fail fast, before any tensor work is done.
    if sde_type not in ("sde", "cps"):
        raise ValueError(f"Unsupported sde_type: {sde_type}")

    # All math is carried out in float32; the result is cast back at the end.
    sample_dtype = sample.dtype
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    last_index = len(self.sigmas) - 1
    step_indices = torch.full(
        (sample.shape[0],),
        int(step_index),
        device=self.sigmas.device,
        dtype=torch.long,
    )
    prev_step_indices = (step_indices + 1).clamp_max(last_index)

    # Reshape per-batch sigmas to (B, 1, ..., 1) so they broadcast over sample.
    broadcast_shape = (-1, *([1] * (sample.ndim - 1)))
    sigma = (
        self.sigmas[step_indices]
        .to(device=sample.device, dtype=sample.dtype)
        .view(*broadcast_shape)
    )
    sigma_prev = (
        self.sigmas[prev_step_indices]
        .to(device=sample.device, dtype=sample.dtype)
        .view(*broadcast_shape)
    )
    dt = sigma_prev - sigma

    if sde_type == "sde":
        # sigma / (1 - sigma) blows up at sigma == 1, so the second sigma in
        # the schedule stands in for it in the denominator there.
        sigma_max = self.sigmas[min(1, last_index)].to(
            device=sample.device, dtype=sample.dtype
        )
        denom_sigma = 1 - torch.where(
            torch.isclose(sigma, sigma.new_tensor(1.0)), sigma_max, sigma
        )
        std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
        prev_sample_mean = (
            sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
            + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
        )
        # dt is expected to be negative (sigmas decreasing); clamp guards sqrt.
        transition_std = std_dev_t * torch.sqrt((-dt).clamp_min(1e-12))
    else:  # "cps"
        std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
        pred_original_sample = sample - sigma * model_output
        noise_estimate = sample + model_output * (1 - sigma)
        sigma_delta = (sigma_prev**2 - std_dev_t**2).clamp_min(0.0)
        prev_sample_mean = pred_original_sample * (
            1 - sigma_prev
        ) + noise_estimate * torch.sqrt(sigma_delta)
        transition_std = std_dev_t

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + transition_std * variance_noise

    # Clamp only for the density evaluation so log/division stay finite.
    std = transition_std.clamp_min(1e-12)
    log_prob = _gaussian_log_prob(prev_sample.detach(), prev_sample_mean, std)
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample.to(sample_dtype), log_prob

python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@ class Req:
133133
# Component modules (populated by the pipeline)
134134
modules: dict[str, Any] = field(default_factory=dict)
135135

136-
trajectory_timesteps: list[torch.Tensor] | None = None
136+
trajectory_timesteps: torch.Tensor | None = None
137137
trajectory_latents: torch.Tensor | None = None
138+
trajectory_log_probs: torch.Tensor | None = None
138139
trajectory_audio_latents: torch.Tensor | None = None
139140

140141
# Extra parameters that might be needed by specific pipeline implementations
@@ -329,8 +330,9 @@ class OutputBatch:
329330
output: torch.Tensor | None = None
330331
audio: torch.Tensor | None = None
331332
audio_sample_rate: int | None = None
332-
trajectory_timesteps: list[torch.Tensor] | None = None
333+
trajectory_timesteps: torch.Tensor | None = None
333334
trajectory_latents: torch.Tensor | None = None
335+
trajectory_log_probs: torch.Tensor | None = None
334336
trajectory_decoded: list[torch.Tensor] | None = None
335337
error: str | None = None
336338
output_file_paths: list[str] | None = None

0 commit comments

Comments
 (0)