Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ class SamplingParams:
# Misc
save_output: bool = True
return_frames: bool = False
rollout: bool = False
rollout_sde_type: str = "sde"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we validate rollout_sde_type in _validate so the request fails early?

rollout_noise_level: float = 0.7
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
# if True, disallow user params to override subclass-defined protected fields
Expand Down Expand Up @@ -306,6 +309,9 @@ def _finite_non_negative_float(
_finite_non_negative_float(
"guidance_rescale", self.guidance_rescale, allow_none=False
)
_finite_non_negative_float(
"rollout_noise_level", self.rollout_noise_level, allow_none=False
)

if self.cfg_normalization is None:
self.cfg_normalization = 0.0
Expand Down Expand Up @@ -808,6 +814,25 @@ def add_cli_args(parser: Any) -> Any:
default=SamplingParams.return_trajectory_latents,
help="Whether to return the trajectory",
)
parser.add_argument(
"--rollout",
action="store_true",
default=SamplingParams.rollout,
help="Enable rollout mode and return per-step log_prob trajectory",
)
parser.add_argument(
"--rollout-sde-type",
type=str,
choices=["sde", "cps"],
default=SamplingParams.rollout_sde_type,
help="Rollout step objective type used in log-prob computation.",
)
parser.add_argument(
"--rollout-noise-level",
type=float,
default=SamplingParams.rollout_noise_level,
help="Noise level used by rollout SDE/CPS step objective.",
)
parser.add_argument(
"--return-trajectory-decoded",
action="store_true",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def generate(
),
trajectory_latents=output_batch.trajectory_latents,
trajectory_timesteps=output_batch.trajectory_timesteps,
trajectory_log_probs=output_batch.trajectory_log_probs,
trajectory_decoded=output_batch.trajectory_decoded,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ async def generations(
true_cfg_scale=request.true_cfg_scale,
negative_prompt=request.negative_prompt,
enable_teacache=request.enable_teacache,
rollout=request.rollout,
rollout_sde_type=request.rollout_sde_type,
rollout_noise_level=request.rollout_noise_level,
output_compression=request.output_compression,
output_quality=request.output_quality,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class ImageGenerationsRequest(BaseModel):
output_quality: Optional[str] = "default"
output_compression: Optional[int] = None
enable_teacache: Optional[bool] = False
rollout: Optional[bool] = False
rollout_sde_type: Optional[str] = "sde"
rollout_noise_level: Optional[float] = 0.7
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend


Expand Down Expand Up @@ -98,6 +101,9 @@ class VideoGenerationsRequest(BaseModel):
output_quality: Optional[str] = "default"
output_compression: Optional[int] = None
output_path: Optional[str] = None
rollout: Optional[bool] = False
rollout_sde_type: Optional[str] = "sde"
rollout_noise_level: Optional[float] = 0.7
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque
frame_interpolation_exp=request.frame_interpolation_exp,
frame_interpolation_scale=request.frame_interpolation_scale,
frame_interpolation_model_path=request.frame_interpolation_model_path,
rollout=request.rollout,
rollout_sde_type=request.rollout_sde_type,
rollout_noise_level=request.rollout_noise_level,
output_path=request.output_path,
output_compression=request.output_compression,
output_quality=request.output_quality,
Expand Down Expand Up @@ -181,6 +184,9 @@ async def create_video(
frame_interpolation_exp: Optional[int] = Form(1),
frame_interpolation_scale: Optional[float] = Form(1.0),
frame_interpolation_model_path: Optional[str] = Form(None),
rollout: Optional[bool] = Form(False),
rollout_sde_type: Optional[str] = Form("sde"),
rollout_noise_level: Optional[float] = Form(0.7),
output_quality: Optional[str] = Form("default"),
output_compression: Optional[int] = Form(None),
extra_body: Optional[str] = Form(None),
Expand Down Expand Up @@ -256,6 +262,9 @@ async def create_video(
frame_interpolation_exp=frame_interpolation_exp,
frame_interpolation_scale=frame_interpolation_scale,
frame_interpolation_model_path=frame_interpolation_model_path,
rollout=rollout,
rollout_sde_type=rollout_sde_type,
rollout_noise_level=rollout_noise_level,
output_compression=output_compression,
output_quality=output_quality,
**(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class GenerationResult:
metrics: dict = field(default_factory=dict)
trajectory_latents: Any = None
trajectory_timesteps: Any = None
trajectory_log_probs: Any = None
trajectory_decoded: Any = None
prompt_index: int = 0
output_file_path: str | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
metrics=result.metrics,
trajectory_timesteps=getattr(result, "trajectory_timesteps", None),
trajectory_latents=getattr(result, "trajectory_latents", None),
trajectory_log_probs=getattr(result, "trajectory_log_probs", None),
noise_pred=getattr(result, "noise_pred", None),
trajectory_decoded=getattr(result, "trajectory_decoded", None),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
"""Flow-matching rollout step utilities for log-prob computation."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we adapt from other open-source diffusion workflows, we shall add acknowledgment here.


import math
from typing import Any, Optional, Union

import torch
from diffusers.utils.torch_utils import randn_tensor


# Precomputed log(sqrt(2*pi)), the normalization constant of a Gaussian log-density.
_LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def _gaussian_sample_and_log_prob(
    mean: torch.Tensor,
    std: torch.Tensor,
    prev_sample: Optional[torch.Tensor],
    generator: Optional[Union[torch.Generator, list[torch.Generator]]],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample from N(mean, std**2) (unless ``prev_sample`` is given) and score it.

    Args:
        mean: Mean of the Gaussian transition, same shape as the latent sample.
        std: Per-element standard deviation (may contain exact zeros; it is
            clamped only for the log-density so sampling stays unchanged).
        prev_sample: If provided, no noise is drawn and this sample is scored
            as-is (its gradient is detached before scoring).
        generator: Optional RNG forwarded to ``randn_tensor``.

    Returns:
        Tuple of (prev_sample, elementwise log-probability tensor).
    """
    if prev_sample is None:
        noise = randn_tensor(
            mean.shape, generator=generator, device=mean.device, dtype=mean.dtype
        )
        prev_sample = mean + std * noise
    # Clamp only inside the density so log/division never hit zero; the draw
    # above intentionally uses the unclamped std.
    safe_std = std.clamp_min(1e-12)
    log_prob = (
        -((prev_sample.detach() - mean) ** 2) / (2 * safe_std**2)
        - torch.log(safe_std)
        - _LOG_SQRT_2PI
    )
    return prev_sample, log_prob


def sde_step_with_logprob(
    self: Any,
    model_output: torch.FloatTensor,
    sample: torch.FloatTensor,
    step_index: int,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
    sde_type: str = "sde",
):
    """Run one rollout step and compute the per-sample log-probability.

    ``sde_type="sde"`` uses the Gaussian transition objective;
    ``sde_type="cps"`` uses the simplified CPS objective. Both transitions are
    Gaussian, so the log-density shares one implementation.

    Args:
        self: A flow-matching scheduler exposing a 1-D ``sigmas`` tensor.
            NOTE(review): this bypasses ``scheduler.step`` and therefore any
            internal multi-step state (e.g. UniPC's ``model_outputs``) — assumed
            to be used with single-step (Euler-style) schedulers; confirm.
        model_output: Predicted velocity/noise for the current step.
        sample: Current latent sample; batch dimension is assumed to be dim 0.
        step_index: Index into ``self.sigmas`` for the current step.
        noise_level: Scale of the injected stochasticity.
        prev_sample: Optional precomputed next sample; when given, it is scored
            instead of drawing fresh noise.
        generator: Optional RNG for reproducible sampling.
        sde_type: ``"sde"`` or ``"cps"``.

    Returns:
        Tuple of (next sample cast back to the input dtype, per-batch-element
        log-probability averaged over all non-batch dimensions).

    Raises:
        ValueError: If ``sde_type`` is unsupported (checked up front so the
            request fails before any tensor work).
    """
    if sde_type not in ("sde", "cps"):
        raise ValueError(f"Unsupported sde_type: {sde_type}")

    # Work in float32 for numerical stability; restore the caller dtype at the end.
    sample_dtype = sample.dtype
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_indices = torch.full(
        (sample.shape[0],),
        int(step_index),
        device=self.sigmas.device,
        dtype=torch.long,
    )
    # Next sigma is the following schedule entry, clamped at the final index.
    prev_step_indices = (step_indices + 1).clamp_max(len(self.sigmas) - 1)
    sigma = self.sigmas[step_indices].to(device=sample.device, dtype=sample.dtype)
    sigma_prev = self.sigmas[prev_step_indices].to(
        device=sample.device, dtype=sample.dtype
    )
    # Reshape to (batch, 1, 1, ...) so the sigmas broadcast over latent dims.
    sigma = sigma.view(-1, *([1] * (sample.ndim - 1)))
    sigma_prev = sigma_prev.view(-1, *([1] * (sample.ndim - 1)))
    # Stand-in sigma used when sigma == 1 would make (1 - sigma) vanish below.
    sigma_max = self.sigmas[min(1, len(self.sigmas) - 1)].to(
        device=sample.device, dtype=sample.dtype
    )
    dt = sigma_prev - sigma  # negative: sigma decreases along the schedule

    if sde_type == "sde":
        # Avoid a zero denominator at the very first step (sigma == 1).
        denom_sigma = 1 - torch.where(
            torch.isclose(sigma, sigma.new_tensor(1.0)), sigma_max, sigma
        )
        std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
        prev_sample_mean = (
            sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
            + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
        )
        step_std = std_dev_t * torch.sqrt((-dt).clamp_min(1e-12))
    else:  # "cps" (validated above)
        std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
        pred_original_sample = sample - sigma * model_output
        noise_estimate = sample + model_output * (1 - sigma)
        # Non-negative by construction (sigma_prev**2 * cos**2), clamp for safety.
        sigma_delta = (sigma_prev**2 - std_dev_t**2).clamp_min(0.0)
        prev_sample_mean = pred_original_sample * (
            1 - sigma_prev
        ) + noise_estimate * torch.sqrt(sigma_delta)
        step_std = std_dev_t

    prev_sample, log_prob = _gaussian_sample_and_log_prob(
        prev_sample_mean, step_std, prev_sample, generator
    )
    # Reduce to one scalar log-prob per batch element.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample.to(sample_dtype), log_prob
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ class Req:
# Component modules (populated by the pipeline)
modules: dict[str, Any] = field(default_factory=dict)

trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_timesteps: torch.Tensor | None = None
trajectory_latents: torch.Tensor | None = None
trajectory_log_probs: torch.Tensor | None = None
trajectory_audio_latents: torch.Tensor | None = None

# Extra parameters that might be needed by specific pipeline implementations
Expand Down Expand Up @@ -329,8 +330,9 @@ class OutputBatch:
output: torch.Tensor | None = None
audio: torch.Tensor | None = None
audio_sample_rate: int | None = None
trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_timesteps: torch.Tensor | None = None
trajectory_latents: torch.Tensor | None = None
trajectory_log_probs: torch.Tensor | None = None
trajectory_decoded: list[torch.Tensor] | None = None
error: str | None = None
output_file_paths: list[str] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def forward(
output=frames,
trajectory_timesteps=batch.trajectory_timesteps,
trajectory_latents=batch.trajectory_latents,
trajectory_log_probs=batch.trajectory_log_probs,
trajectory_decoded=trajectory_decoded,
metrics=batch.metrics,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch:
output=video,
trajectory_timesteps=batch.trajectory_timesteps,
trajectory_latents=batch.trajectory_latents,
trajectory_log_probs=batch.trajectory_log_probs,
trajectory_decoded=None,
metrics=batch.metrics,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
TransformerLoader,
)
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.pipelines.patches.flow_matching_with_logprob import (
sde_step_with_logprob,
)
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages.base import (
PipelineStage,
Expand Down Expand Up @@ -695,6 +698,7 @@ def _post_denoising_loop(
latents: torch.Tensor,
trajectory_latents: list,
trajectory_timesteps: list,
trajectory_log_probs: list,
server_args: ServerArgs,
is_warmup: bool = False,
):
Expand All @@ -705,6 +709,10 @@ def _post_denoising_loop(
else:
trajectory_tensor = None
trajectory_timesteps_tensor = None
if trajectory_log_probs:
trajectory_log_probs_tensor = torch.stack(trajectory_log_probs, dim=1)
else:
trajectory_log_probs_tensor = None

# Gather results if using sequence parallelism
latents, trajectory_tensor = self._postprocess_sp_latents(
Expand All @@ -731,6 +739,8 @@ def _post_denoising_loop(
if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
batch.trajectory_latents = trajectory_tensor.cpu()
if trajectory_log_probs_tensor is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether this works in SP. We can leave it to the future.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like SP isn’t supported. Could we add a TODO or FIXME?

batch.trajectory_log_probs = trajectory_log_probs_tensor.cpu()

# Update batch with final latents
batch.latents = self.server_args.pipeline_config.post_denoising_loop(
Expand Down Expand Up @@ -998,6 +1008,11 @@ def forward(
# Initialize lists for ODE trajectory
trajectory_timesteps: list[torch.Tensor] = []
trajectory_latents: list[torch.Tensor] = []
trajectory_log_probs: list[torch.Tensor] = []
rollout_enabled = bool(batch.rollout)
rollout_sde_type = batch.rollout_sde_type

rollout_noise_level = batch.rollout_noise_level

# Run denoising loop
denoising_start_time = time.time()
Expand All @@ -1006,6 +1021,13 @@ def forward(
is_warmup = batch.is_warmup
self.scheduler.set_begin_index(0)
timesteps_cpu = timesteps.cpu()
rollout_step_indices: list[int] = []
if rollout_enabled:
scheduler_timesteps = self.scheduler.timesteps
rollout_step_indices = [
self.scheduler.index_for_timestep(t.to(scheduler_timesteps.device))
for t in timesteps_cpu
]
num_timesteps = timesteps_cpu.shape[0]
with torch.autocast(
device_type=current_platform.device_type,
Expand Down Expand Up @@ -1085,13 +1107,25 @@ def forward(
batch.noise_pred = noise_pred

# Compute the previous noisy sample
latents = self.scheduler.step(
model_output=noise_pred,
timestep=t_device,
sample=latents,
**extra_step_kwargs,
return_dict=False,
)[0]
if rollout_enabled:
latents, step_log_prob = sde_step_with_logprob(
Copy link
Collaborator

@alphabetc1 alphabetc1 Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rollout path bypasses self.scheduler.step(...) and directly computes the next sample from sigmas.

That does not seem equivalent for multi-step schedulers like FlowUniPCMultistepScheduler, because their step() also updates internal state such as last_sample, model_outputs, timestep_list, and lower_order_nums.

Is this by design?

self.scheduler,
model_output=noise_pred,
sample=latents,
step_index=rollout_step_indices[i],
generator=batch.generator,
sde_type=rollout_sde_type,
noise_level=rollout_noise_level,
)
trajectory_log_probs.append(step_log_prob)
else:
latents = self.scheduler.step(
model_output=noise_pred,
timestep=t_device,
sample=latents,
**extra_step_kwargs,
return_dict=False,
)[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference in input parameters between sde_step_with_logprob and self.scheduler.step is a little unclear to me. The strangest part is that sde_step_with_logprob takes self.scheduler as a parameter, while self.scheduler.step is an object method of the scheduler. Could we share the same design pattern for parameters, e.g.:

  1. change sde_step_with_logprob to self.scheduler.sde_step_with_logprob
  2. Or, only have one entrypoint self.scheduler.step, but pass in step_index=rollout_step_indices[i], generator=batch.generator, sde_type=rollout_sde_type, noise_level=rollout_noise_level, as kwargs?

Indeed, I am not so sure about the process of SDE and CPS. I shall ask BBuf, mick, and Yuhao for help on the design.


latents = self.post_forward_for_ti2v_task(
batch, server_args, reserved_frames_mask, latents, z
Expand Down Expand Up @@ -1126,6 +1160,7 @@ def forward(
latents=latents,
trajectory_latents=trajectory_latents,
trajectory_timesteps=trajectory_timesteps,
trajectory_log_probs=trajectory_log_probs,
server_args=server_args,
is_warmup=is_warmup,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def forward(
latents=latents,
trajectory_latents=[],
trajectory_timesteps=[],
trajectory_log_probs=[],
server_args=server_args,
)

Expand Down
Loading