diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index e5388df60afc..c6ef163b35b6 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -159,6 +159,9 @@ class SamplingParams: # Misc save_output: bool = True return_frames: bool = False + rollout: bool = False + rollout_sde_type: str = "sde" + rollout_noise_level: float = 0.7 return_trajectory_latents: bool = False # returns all latents for each timestep return_trajectory_decoded: bool = False # returns decoded latents for each timestep # if True, disallow user params to override subclass-defined protected fields @@ -306,6 +309,9 @@ def _finite_non_negative_float( _finite_non_negative_float( "guidance_rescale", self.guidance_rescale, allow_none=False ) + _finite_non_negative_float( + "rollout_noise_level", self.rollout_noise_level, allow_none=False + ) if self.cfg_normalization is None: self.cfg_normalization = 0.0 @@ -808,6 +814,25 @@ def add_cli_args(parser: Any) -> Any: default=SamplingParams.return_trajectory_latents, help="Whether to return the trajectory", ) + parser.add_argument( + "--rollout", + action="store_true", + default=SamplingParams.rollout, + help="Enable rollout mode and return per-step log_prob trajectory", + ) + parser.add_argument( + "--rollout-sde-type", + type=str, + choices=["sde", "cps"], + default=SamplingParams.rollout_sde_type, + help="Rollout step objective type used in log-prob computation.", + ) + parser.add_argument( + "--rollout-noise-level", + type=float, + default=SamplingParams.rollout_noise_level, + help="Noise level used by rollout SDE/CPS step objective.", + ) parser.add_argument( "--return-trajectory-decoded", action="store_true", diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index 
ace021b20385..6221aa34f535 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -216,6 +216,7 @@ def generate( ), trajectory_latents=output_batch.trajectory_latents, trajectory_timesteps=output_batch.trajectory_timesteps, + trajectory_log_probs=output_batch.trajectory_log_probs, trajectory_decoded=output_batch.trajectory_decoded, ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py index 9d4cb410c18b..8728b0c2ce8d 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py @@ -129,6 +129,9 @@ async def generations( true_cfg_scale=request.true_cfg_scale, negative_prompt=request.negative_prompt, enable_teacache=request.enable_teacache, + rollout=request.rollout, + rollout_sde_type=request.rollout_sde_type, + rollout_noise_level=request.rollout_noise_level, output_compression=request.output_compression, output_quality=request.output_quality, ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py index c959e2f22259..953ce388277c 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py @@ -46,6 +46,9 @@ class ImageGenerationsRequest(BaseModel): output_quality: Optional[str] = "default" output_compression: Optional[int] = None enable_teacache: Optional[bool] = False + rollout: Optional[bool] = False + rollout_sde_type: Optional[str] = "sde" + rollout_noise_level: Optional[float] = 0.7 diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend @@ -98,6 +101,9 @@ class VideoGenerationsRequest(BaseModel): output_quality: Optional[str] = "default" 
output_compression: Optional[int] = None output_path: Optional[str] = None + rollout: Optional[bool] = False + rollout_sde_type: Optional[str] = "sde" + rollout_noise_level: Optional[float] = 0.7 diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py index 9db0fde3ca61..a3aabd01d88e 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py @@ -75,6 +75,9 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque frame_interpolation_exp=request.frame_interpolation_exp, frame_interpolation_scale=request.frame_interpolation_scale, frame_interpolation_model_path=request.frame_interpolation_model_path, + rollout=request.rollout, + rollout_sde_type=request.rollout_sde_type, + rollout_noise_level=request.rollout_noise_level, output_path=request.output_path, output_compression=request.output_compression, output_quality=request.output_quality, @@ -181,6 +184,9 @@ async def create_video( frame_interpolation_exp: Optional[int] = Form(1), frame_interpolation_scale: Optional[float] = Form(1.0), frame_interpolation_model_path: Optional[str] = Form(None), + rollout: Optional[bool] = Form(False), + rollout_sde_type: Optional[str] = Form("sde"), + rollout_noise_level: Optional[float] = Form(0.7), output_quality: Optional[str] = Form("default"), output_compression: Optional[int] = Form(None), extra_body: Optional[str] = Form(None), @@ -256,6 +262,9 @@ async def create_video( frame_interpolation_exp=frame_interpolation_exp, frame_interpolation_scale=frame_interpolation_scale, frame_interpolation_model_path=frame_interpolation_model_path, + rollout=rollout, + rollout_sde_type=rollout_sde_type, + rollout_noise_level=rollout_noise_level, output_compression=output_compression, 
output_quality=output_quality, **( diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py index 95bc98ef7ac0..ea527981e222 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py @@ -108,6 +108,7 @@ class GenerationResult: metrics: dict = field(default_factory=dict) trajectory_latents: Any = None trajectory_timesteps: Any = None + trajectory_log_probs: Any = None trajectory_decoded: Any = None prompt_index: int = 0 output_file_path: str | None = None diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index ca51f21fb9fe..3e7ffcf329ca 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -233,6 +233,7 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch: metrics=result.metrics, trajectory_timesteps=getattr(result, "trajectory_timesteps", None), trajectory_latents=getattr(result, "trajectory_latents", None), + trajectory_log_probs=getattr(result, "trajectory_log_probs", None), noise_pred=getattr(result, "noise_pred", None), trajectory_decoded=getattr(result, "trajectory_decoded", None), ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/patches/__init__.py b/python/sglang/multimodal_gen/runtime/pipelines/patches/__init__.py new file mode 100644 index 000000000000..9881313609aa --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/patches/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/python/sglang/multimodal_gen/runtime/pipelines/patches/flow_matching_with_logprob.py b/python/sglang/multimodal_gen/runtime/pipelines/patches/flow_matching_with_logprob.py new file mode 100644 index 000000000000..706e4936bc02 --- /dev/null +++ 
def _gaussian_log_prob(
    value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
) -> torch.Tensor:
    """Return the element-wise log-density of ``value`` under N(mean, std**2)."""
    return (
        -((value - mean) ** 2) / (2 * (std**2))
        - torch.log(std)
        - 0.5 * math.log(2 * math.pi)
    )


def sde_step_with_logprob(
    self: Any,
    model_output: torch.FloatTensor,
    sample: torch.FloatTensor,
    step_index: int,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
    sde_type: str = "sde",
):
    """Run one stochastic rollout step and compute its per-sample log-probability.

    The transition ``sample -> prev_sample`` is modelled as a Gaussian whose
    mean and standard deviation are derived from the scheduler sigmas at
    ``step_index`` and ``step_index + 1``.  All arithmetic is done in float32.

    Args:
        self: Scheduler-like object; only its ``sigmas`` attribute (an
            indexable 1-D tensor supporting ``len``) is read.
        model_output: Model prediction for the current step, same shape as
            ``sample``.
        sample: Current latents; the first dimension is treated as the batch.
        step_index: Index of the current sigma in ``self.sigmas``.  The next
            index is clamped to the last entry, so at the last index
            ``sigma_prev == sigma`` and ``dt == 0``.
        noise_level: Scale of the injected noise.  For ``"sde"`` it multiplies
            ``sqrt(sigma / (1 - sigma))``; for ``"cps"`` it enters through
            ``sin(noise_level * pi / 2)``.
        prev_sample: If given, no noise is drawn and the log-probability of
            this (detached) value under the transition is evaluated instead.
        generator: Forwarded to ``randn_tensor`` when noise has to be drawn.
        sde_type: Transition type.
            1. ``"sde"``: standard stochastic differential equation step.
            2. ``"cps"``: Coefficients-Preserving Sampling step.

    Returns:
        A tuple ``(prev_sample, log_prob)``.  ``prev_sample`` is cast back to
        the dtype ``sample`` had on entry.  ``log_prob`` is float32 with one
        entry per batch element: the element-wise Gaussian log-density
        averaged (not summed) over all non-batch dimensions.

    Raises:
        ValueError: If ``sde_type`` is neither ``"sde"`` nor ``"cps"``.

    Note:
        The ``"sde"`` mean divides by ``sigma``; a step whose sigma is exactly
        0 therefore yields non-finite values.
    """
    # Fail fast, before any tensor work, on an unknown transition type.
    if sde_type not in ("sde", "cps"):
        raise ValueError(f"Unsupported sde_type: {sde_type}")

    sample_dtype = sample.dtype
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    sigmas = self.sigmas
    last_index = len(sigmas) - 1
    step_indices = torch.full(
        (sample.shape[0],),
        int(step_index),
        device=sigmas.device,
        dtype=torch.long,
    )
    prev_step_indices = (step_indices + 1).clamp_max(last_index)

    # Reshape the per-sample sigmas to (batch, 1, ..., 1) so they broadcast
    # against the latents.
    broadcast_shape = (-1, *([1] * (sample.ndim - 1)))
    sigma = (
        sigmas[step_indices]
        .to(device=sample.device, dtype=sample.dtype)
        .view(broadcast_shape)
    )
    sigma_prev = (
        sigmas[prev_step_indices]
        .to(device=sample.device, dtype=sample.dtype)
        .view(broadcast_shape)
    )
    dt = sigma_prev - sigma

    if sde_type == "sde":
        # At sigma == 1 the denominator (1 - sigma) would be zero, so the
        # second scheduler sigma (or the only one) is substituted there.
        sigma_max = sigmas[min(1, last_index)].to(
            device=sample.device, dtype=sample.dtype
        )
        denom_sigma = 1 - torch.where(
            torch.isclose(sigma, sigma.new_tensor(1.0)), sigma_max, sigma
        )
        std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
        prev_sample_mean = (
            sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
            + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
        )
        # dt is negative while sigmas decrease; clamp guards the sqrt at dt == 0.
        noise_scale = std_dev_t * torch.sqrt((-dt).clamp_min(1e-12))
    else:  # "cps"
        std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
        pred_original_sample = sample - sigma * model_output
        noise_estimate = sample + model_output * (1 - sigma)
        sigma_delta = (sigma_prev**2 - std_dev_t**2).clamp_min(0.0)
        prev_sample_mean = pred_original_sample * (
            1 - sigma_prev
        ) + noise_estimate * torch.sqrt(sigma_delta)
        noise_scale = std_dev_t

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + noise_scale * variance_noise

    # Both transitions are Gaussian with std ``noise_scale``; the floor keeps
    # the log and the division finite when the scale collapses to zero.
    std = noise_scale.clamp_min(1e-12)
    log_prob = _gaussian_log_prob(prev_sample.detach(), prev_sample_mean, std)

    # Average over every non-batch dimension.  Skip the reduction for 1-D
    # input: an empty ``dim`` tuple may be interpreted as "reduce everything",
    # which would collapse the batch axis.
    reduce_dims = tuple(range(1, log_prob.ndim))
    if reduce_dims:
        log_prob = log_prob.mean(dim=reduce_dims)
    return prev_sample.to(sample_dtype), log_prob
b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py @@ -231,6 +231,7 @@ def forward( output=frames, trajectory_timesteps=batch.trajectory_timesteps, trajectory_latents=batch.trajectory_latents, + trajectory_log_probs=batch.trajectory_log_probs, trajectory_decoded=trajectory_decoded, metrics=batch.metrics, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py index 28df749bdf7a..bc85d3298932 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding_av.py @@ -72,6 +72,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> OutputBatch: output=video, trajectory_timesteps=batch.trajectory_timesteps, trajectory_latents=batch.trajectory_latents, + trajectory_log_probs=batch.trajectory_log_probs, trajectory_decoded=None, metrics=batch.metrics, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 844999554d99..3afd633cfd39 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -59,6 +59,9 @@ TransformerLoader, ) from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines.patches.flow_matching_with_logprob import ( + sde_step_with_logprob, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( PipelineStage, @@ -695,6 +698,7 @@ def _post_denoising_loop( latents: torch.Tensor, trajectory_latents: list, trajectory_timesteps: list, + trajectory_log_probs: list, server_args: ServerArgs, is_warmup: bool = False, ): @@ -705,6 +709,10 @@ def 
_post_denoising_loop( else: trajectory_tensor = None trajectory_timesteps_tensor = None + if trajectory_log_probs: + trajectory_log_probs_tensor = torch.stack(trajectory_log_probs, dim=1) + else: + trajectory_log_probs_tensor = None # Gather results if using sequence parallelism latents, trajectory_tensor = self._postprocess_sp_latents( @@ -731,6 +739,8 @@ def _post_denoising_loop( if trajectory_tensor is not None and trajectory_timesteps_tensor is not None: batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu() batch.trajectory_latents = trajectory_tensor.cpu() + if trajectory_log_probs_tensor is not None: + batch.trajectory_log_probs = trajectory_log_probs_tensor.cpu() # Update batch with final latents batch.latents = self.server_args.pipeline_config.post_denoising_loop( @@ -998,6 +1008,11 @@ def forward( # Initialize lists for ODE trajectory trajectory_timesteps: list[torch.Tensor] = [] trajectory_latents: list[torch.Tensor] = [] + trajectory_log_probs: list[torch.Tensor] = [] + rollout_enabled = bool(batch.rollout) + rollout_sde_type = batch.rollout_sde_type + + rollout_noise_level = batch.rollout_noise_level # Run denoising loop denoising_start_time = time.time() @@ -1006,6 +1021,13 @@ def forward( is_warmup = batch.is_warmup self.scheduler.set_begin_index(0) timesteps_cpu = timesteps.cpu() + rollout_step_indices: list[int] = [] + if rollout_enabled: + scheduler_timesteps = self.scheduler.timesteps + rollout_step_indices = [ + self.scheduler.index_for_timestep(t.to(scheduler_timesteps.device)) + for t in timesteps_cpu + ] num_timesteps = timesteps_cpu.shape[0] with torch.autocast( device_type=current_platform.device_type, @@ -1085,13 +1107,25 @@ def forward( batch.noise_pred = noise_pred # Compute the previous noisy sample - latents = self.scheduler.step( - model_output=noise_pred, - timestep=t_device, - sample=latents, - **extra_step_kwargs, - return_dict=False, - )[0] + if rollout_enabled: + latents, step_log_prob = sde_step_with_logprob( + 
self.scheduler, + model_output=noise_pred, + sample=latents, + step_index=rollout_step_indices[i], + generator=batch.generator, + sde_type=rollout_sde_type, + noise_level=rollout_noise_level, + ) + trajectory_log_probs.append(step_log_prob) + else: + latents = self.scheduler.step( + model_output=noise_pred, + timestep=t_device, + sample=latents, + **extra_step_kwargs, + return_dict=False, + )[0] latents = self.post_forward_for_ti2v_task( batch, server_args, reserved_frames_mask, latents, z @@ -1126,6 +1160,7 @@ def forward( latents=latents, trajectory_latents=trajectory_latents, trajectory_timesteps=trajectory_timesteps, + trajectory_log_probs=trajectory_log_probs, server_args=server_args, is_warmup=is_warmup, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py index 504fc429e03b..25376b7d27a5 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py @@ -215,6 +215,7 @@ def forward( latents=latents, trajectory_latents=[], trajectory_timesteps=[], + trajectory_log_probs=[], server_args=server_args, )