diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index 7dcf9bf1dc88..060109689a14 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -172,6 +172,13 @@ class SamplingParams: # Misc save_output: bool = True return_frames: bool = False + rollout: bool = False + rollout_sde_type: str = "sde" + rollout_noise_level: float = 0.7 + rollout_log_prob_no_const: bool = False # exclude constants in rollout logprob + rollout_debug_mode: bool = ( + False # return rollout debug tensors (intermediate states) + ) return_trajectory_latents: bool = False # returns all latents for each timestep return_trajectory_decoded: bool = False # returns decoded latents for each timestep # if True, disallow user params to override subclass-defined protected fields @@ -319,6 +326,16 @@ def _finite_non_negative_float( _finite_non_negative_float( "guidance_rescale", self.guidance_rescale, allow_none=False ) + _finite_non_negative_float( + "rollout_noise_level", self.rollout_noise_level, allow_none=False + ) + + _VALID_ROLLOUT_SDE_TYPES = ("sde", "cps", "ode") + if self.rollout_sde_type not in _VALID_ROLLOUT_SDE_TYPES: + raise ValueError( + f"rollout_sde_type must be one of {_VALID_ROLLOUT_SDE_TYPES}, " + f"got {self.rollout_sde_type!r}" + ) if self.cfg_normalization is None: self.cfg_normalization = 0.0 @@ -803,6 +820,32 @@ def add_argument(*name_or_flags, **kwargs): action="store_true", help="Whether to return the trajectory", ) + add_argument( + "--rollout", + action="store_true", + help="Enable rollout mode and return per-step log_prob trajectory", + ) + add_argument( + "--rollout-sde-type", + type=str, + choices=["sde", "cps", "ode"], + help="Rollout step objective type used in log-prob computation.", + ) + add_argument( + "--rollout-noise-level", + type=float, + help="Noise level used by rollout SDE/CPS step objective.", + ) + add_argument( + "--rollout-log-prob-no-const", + action=StoreBoolean, + help="If true, return rollout log-prob without constant terms.", + ) + add_argument( + "--rollout-debug-mode", + action=StoreBoolean, + help="If true, return rollout debug tensors (variance noise, mean, std, model output).", + ) add_argument( "--return-trajectory-decoded", action="store_true", diff --git a/python/sglang/multimodal_gen/runtime/distributed/communication_op.py b/python/sglang/multimodal_gen/runtime/distributed/communication_op.py index 2714d7ce5119..2da348cfc028 100644 --- a/python/sglang/multimodal_gen/runtime/distributed/communication_op.py +++ b/python/sglang/multimodal_gen/runtime/distributed/communication_op.py @@ -44,6 +44,11 @@ def sequence_model_parallel_all_gather( return get_sp_group().all_gather(input_, dim) +def sequence_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_sp_group().all_reduce(input_) + + def cfg_model_parallel_all_gather( input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False ) -> torch.Tensor: diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index 2f71ad1d1fd5..29af3ee676a0 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -253,6 +253,7 @@ def generate( ), 
trajectory_latents=output_batch.trajectory_latents, trajectory_timesteps=output_batch.trajectory_timesteps, + rollout_trajectory_data=output_batch.rollout_trajectory_data, trajectory_decoded=output_batch.trajectory_decoded, ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py index 21b6e1a1f000..69605b4e44fc 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py @@ -108,6 +108,7 @@ class GenerationResult: metrics: dict = field(default_factory=dict) trajectory_latents: Any = None trajectory_timesteps: Any = None + rollout_trajectory_data: Any = None trajectory_decoded: Any = None prompt_index: int = 0 output_file_path: str | None = None diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index 8757da74ee68..e1b1baa50f21 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -237,6 +237,9 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch: metrics=result.metrics, trajectory_timesteps=getattr(result, "trajectory_timesteps", None), trajectory_latents=getattr(result, "trajectory_latents", None), + rollout_trajectory_data=getattr( + result, "rollout_trajectory_data", None + ), noise_pred=getattr(result, "noise_pred", None), trajectory_decoded=getattr(result, "trajectory_decoded", None), ) diff --git a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py index ca191075f5ef..8a5d18435b7d 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py @@ -876,7 +876,8 @@ def forward( x = list(unified.unbind(dim=0)) x = self.unpatchify(x, x_size, patch_size, f_patch_size) - return -x[0] + # Keep batch dim so output shape matches input (e.g. rollout/scheduler expect same ndim). + return -torch.stack(x) EntryClass = ZImageTransformer2DModel diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py index 980ff50f91d9..d9bc63ecd535 100644 --- a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py @@ -32,6 +32,9 @@ from diffusers.utils import BaseOutput from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler +from sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin import ( + SchedulerRLMixin, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) @@ -51,7 +54,9 @@ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor -class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): +class FlowMatchEulerDiscreteScheduler( + SchedulerMixin, ConfigMixin, BaseScheduler, SchedulerRLMixin +): """ Euler scheduler. 
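Note on the class change above: the rollout mixin is appended at the end of the base list, so all existing scheduler behavior resolves first in the MRO and the mixin only contributes new method names (prepare_rollout, flow_sde_sampling, collect_rollout_log_probs, ...). A minimal sketch of how callers feature-detect rollout support, using toy stand-in classes rather than code from this diff:

    # Toy sketch: rollout support is opt-in and detected via isinstance().
    class SchedulerRLMixin:
        def already_prepared_rollout(self, batch) -> bool:
            return getattr(batch, "_rollout_session_data", None) is not None

    class BaseScheduler:  # stand-in for the real base class
        pass

    class MyScheduler(BaseScheduler, SchedulerRLMixin):
        pass

    scheduler = MyScheduler()
    if isinstance(scheduler, SchedulerRLMixin):  # same check used in denoising.py below
        ...  # safe to call prepare_rollout / collect_rollout_log_probs
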
@@ -447,6 +452,7 @@ def step( s_noise: float = 1.0, generator: torch.Generator | None = None, per_token_timesteps: torch.Tensor | None = None, + batch=None, return_dict: bool = True, ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]: """ @@ -516,12 +522,22 @@ def step( next_sigma = sigma_next dt = sigma_next - sigma - if self.config.stochastic_sampling: - x0 = sample - current_sigma * model_output - noise = torch.randn_like(sample) - prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + if batch is not None and self.already_prepared_rollout(batch): + prev_sample, log_prob_local_sum, log_prob_local_count = ( + self.flow_sde_sampling( + batch, model_output, sample, current_sigma, next_sigma, generator + ) + ) + self.append_local_rollout_log_probs( + batch, log_prob_local_sum, log_prob_local_count + ) else: - prev_sample = sample + dt * model_output + if self.config.stochastic_sampling: + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + prev_sample = sample + dt * model_output # upon completion increase step index by one assert self._step_index is not None, "_step_index should not be None" diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py index bb30bd10b712..7ec90fba7836 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py @@ -22,6 +22,9 @@ import torch from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams +from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import ( + RolloutTrajectoryData, +) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import ( _sanitize_for_logging, @@ -131,8 +134,9 @@ class Req: # Component modules (populated by the pipeline) modules: dict[str, Any] = field(default_factory=dict) - trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_timesteps: torch.Tensor | None = None trajectory_latents: torch.Tensor | None = None + rollout_trajectory_data: RolloutTrajectoryData | None = None trajectory_audio_latents: torch.Tensor | None = None # Extra parameters that might be needed by specific pipeline implementations @@ -333,8 +337,9 @@ class OutputBatch: output: torch.Tensor | None = None audio: torch.Tensor | None = None audio_sample_rate: int | None = None - trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_timesteps: torch.Tensor | None = None trajectory_latents: torch.Tensor | None = None + rollout_trajectory_data: RolloutTrajectoryData | None = None trajectory_decoded: list[torch.Tensor] | None = None error: str | None = None output_file_paths: list[str] | None = None diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py index 006139487595..29ef27fe4da0 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/decoding.py @@ -236,6 +236,7 @@ def forward( output=frames, trajectory_timesteps=batch.trajectory_timesteps, trajectory_latents=batch.trajectory_latents, + rollout_trajectory_data=batch.rollout_trajectory_data, trajectory_decoded=trajectory_decoded, metrics=batch.metrics, noise_pred=None, diff --git 
a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 937446899278..c9bdbb21b704 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -73,6 +73,12 @@ AttentionBackendEnum, current_platform, ) +from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import ( + RolloutTrajectoryData, +) +from sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin import ( + SchedulerRLMixin, +) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -129,6 +135,43 @@ def __init__( self._cached_num_steps = None self._is_warmed_up = False + def _maybe_prepare_rollout(self, batch: Req): + """Prepare denoising loop for rollout.""" + if not isinstance(self.scheduler, SchedulerRLMixin): + if batch.rollout: + raise ValueError( + f"Scheduler {type(self.scheduler)} does not support rollout" + ) + return + + self.scheduler.release_rollout_resources(batch) + if batch.rollout: + self.scheduler.prepare_rollout( + batch=batch, + pipeline_config=self.server_args.pipeline_config, + ) + + def _maybe_collect_rollout_log_probs(self, batch: Req): + """Get rollout log probs and store into batch for reward calculation.""" + if not isinstance(self.scheduler, SchedulerRLMixin): + if batch.rollout: + raise ValueError( + f"Scheduler {type(self.scheduler)} does not support rollout" + ) + return + + if batch.rollout: + if batch.rollout_trajectory_data is None: + batch.rollout_trajectory_data = RolloutTrajectoryData() + batch.rollout_trajectory_data.rollout_log_probs = ( + self.scheduler.collect_rollout_log_probs(batch) + ) + if getattr(batch, "rollout_debug_mode", False): + batch.rollout_trajectory_data.rollout_debug_tensors = ( + self.scheduler.collect_rollout_debug_tensors(batch) + ) + self.scheduler.release_rollout_resources(batch) + def _maybe_enable_torch_compile(self, module: object) -> None: """ Compile a module with torch.compile, and enable inductor overlap tweak if available. 
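The two hooks above bracket the denoising loop: _maybe_prepare_rollout attaches per-request rollout state, scheduler.step() accumulates per-step log-prob sums while the loop runs, and _maybe_collect_rollout_log_probs reduces them into batch.rollout_trajectory_data (and releases the session state). A condensed sketch of the call order wired up in the hunks below; the loop body is simplified and extra_step_kwargs now carries the batch:

    self._maybe_prepare_rollout(batch)                # attaches RolloutSessionData
    for t in timesteps:                               # denoising loop (simplified)
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
    self._maybe_collect_rollout_log_probs(batch)      # fills batch.rollout_trajectory_data
    # release_rollout_resources() runs inside the collect hook above.
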
@@ -557,10 +600,12 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): else: self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch) + self._maybe_prepare_rollout(batch) + # Prepare extra step kwargs for scheduler extra_step_kwargs = self.prepare_extra_func_kwargs( self.scheduler.step, - {"generator": batch.generator, "eta": batch.eta}, + {"generator": batch.generator, "eta": batch.eta, "batch": batch}, ) # Setup precision and autocast settings @@ -718,6 +763,9 @@ def _post_denoising_loop( trajectory_tensor = None trajectory_timesteps_tensor = None + # Gather log probs for rollout + self._maybe_collect_rollout_log_probs(batch) + # Gather results if using sequence parallelism latents, trajectory_tensor = self._postprocess_sp_latents( batch, latents, trajectory_tensor @@ -1075,7 +1123,6 @@ def forward( guidance=guidance, latents=latents, ) - # Save noise_pred to batch for external access (e.g., ComfyUI) if server_args.comfyui_mode: batch.noise_pred = noise_pred diff --git a/python/sglang/multimodal_gen/runtime/post_training/__init__.py b/python/sglang/multimodal_gen/runtime/post_training/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/multimodal_gen/runtime/post_training/rl_dataclasses.py b/python/sglang/multimodal_gen/runtime/post_training/rl_dataclasses.py new file mode 100644 index 000000000000..4ce8b0d94864 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/post_training/rl_dataclasses.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RL-specific dataclasses used by post-training and rollout paths.""" + +from dataclasses import dataclass, field +from typing import Any + +import torch + + +@dataclass +class RolloutSessionData: + """Per-batch rollout state created by prepare_rollout(), lives on the batch object. + + Cleared by setting ``batch._rollout_session_data = None``. 
+ """ + + pipeline_config: Any = None + sigma_max: float = 0.0 + latents_shape: tuple | None = None + noise_buffer: torch.Tensor | None = None + + local_log_prob_sum: list[torch.Tensor] = field(default_factory=list) + local_log_prob_count: list[torch.Tensor] = field(default_factory=list) + + local_variance_noises: list[torch.Tensor] = field(default_factory=list) + local_prev_sample_means: list[torch.Tensor] = field(default_factory=list) + local_noise_std_devs: list[torch.Tensor] = field(default_factory=list) + local_model_outputs: list[torch.Tensor] = field(default_factory=list) + + +@dataclass +class RolloutDebugTensors: + """Container for rollout debug tensors collected during denoising.""" + + rollout_variance_noises: torch.Tensor | None = None + rollout_prev_sample_means: torch.Tensor | None = None + rollout_noise_std_devs: torch.Tensor | None = None + rollout_model_outputs: torch.Tensor | None = None + + +@dataclass +class RolloutTrajectoryData: + """Container for rollout-specific trajectory outputs.""" + + rollout_log_probs: torch.Tensor | None = None + rollout_debug_tensors: RolloutDebugTensors | None = None diff --git a/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_debug_mixin.py b/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_debug_mixin.py new file mode 100644 index 000000000000..acdacae7117b --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_debug_mixin.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Debug tensor helpers for rollout-enabled schedulers.""" + +import torch + +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_sp_world_size, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import ( + RolloutDebugTensors, + RolloutSessionData, +) + + +class SchedulerRLDebugMixin: + @staticmethod + def _reset_rollout_debug_tensors(rollout_session_data: RolloutSessionData) -> None: + rollout_session_data.local_variance_noises = [] + rollout_session_data.local_prev_sample_means = [] + rollout_session_data.local_noise_std_devs = [] + rollout_session_data.local_model_outputs = [] + + def append_local_rollout_debug_tensors( + self, + batch, + *, + variance_noise: torch.Tensor, + prev_sample_mean: torch.Tensor, + noise_std_dev: torch.Tensor, + model_output: torch.Tensor, + ) -> None: + rollout_session_data = batch._rollout_session_data + batch_size = variance_noise.shape[0] + rollout_session_data.local_variance_noises.append(variance_noise) + rollout_session_data.local_prev_sample_means.append(prev_sample_mean) + rollout_session_data.local_noise_std_devs.append( + noise_std_dev.expand((batch_size, 1)) + ) + rollout_session_data.local_model_outputs.append(model_output) + + def consume_local_rollout_debug_tensors( + self, + batch, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + rollout_session_data = batch._rollout_session_data + variance_noises = torch.stack(rollout_session_data.local_variance_noises, dim=1) + prev_sample_means = torch.stack( + rollout_session_data.local_prev_sample_means, dim=1 + ) + noise_std_devs = torch.stack(rollout_session_data.local_noise_std_devs, dim=1) + model_outputs = torch.stack(rollout_session_data.local_model_outputs, dim=1) + self._reset_rollout_debug_tensors(rollout_session_data) + return variance_noises, prev_sample_means, noise_std_devs, model_outputs + + def collect_rollout_debug_tensors(self, batch: Req) -> 
RolloutDebugTensors: + """ + Consume rollout debug tensors and merge for all SP ranks. + + Returns rollout debug tensors with shape [B, T, ...]. + """ + rollout_session_data = batch._rollout_session_data + variance_noises, prev_sample_means, noise_std_devs, model_outputs = ( + self.consume_local_rollout_debug_tensors(batch) + ) + + if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False): + variance_noises = variance_noises.to(get_local_torch_device()) + prev_sample_means = prev_sample_means.to(get_local_torch_device()) + noise_std_devs = noise_std_devs.to(get_local_torch_device()) + model_outputs = model_outputs.to(get_local_torch_device()) + pipeline_config = rollout_session_data.pipeline_config + bsz, num_steps = variance_noises.shape[0], variance_noises.shape[1] + + # [B, T, ...] -> [B*T, ...] + variance_noises_packed = variance_noises.contiguous().reshape( + bsz * num_steps, *variance_noises.shape[2:] + ) + prev_sample_means_packed = prev_sample_means.contiguous().reshape( + bsz * num_steps, *prev_sample_means.shape[2:] + ) + model_outputs_packed = model_outputs.contiguous().reshape( + bsz * num_steps, *model_outputs.shape[2:] + ) + + # Gather on packed tensors first. + variance_noises_packed = pipeline_config.gather_latents_for_sp( + variance_noises_packed + ) + prev_sample_means_packed = pipeline_config.gather_latents_for_sp( + prev_sample_means_packed + ) + model_outputs_packed = pipeline_config.gather_latents_for_sp( + model_outputs_packed + ) + + # Unpack back to [B, T, ...]. + variance_noises = variance_noises_packed.reshape( + bsz, num_steps, *variance_noises_packed.shape[1:] + ) + prev_sample_means = prev_sample_means_packed.reshape( + bsz, num_steps, *prev_sample_means_packed.shape[1:] + ) + model_outputs = model_outputs_packed.reshape( + bsz, num_steps, *model_outputs_packed.shape[1:] + ) + # noise_std_devs is same on every device, not a sharded latent tensor. 
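+        # (shape note) noise_std_devs is [B, T, 1] and derived from scalar sigmas,
+        # so only the latent-shaped tensors above need the pack/gather round trip.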
+ + return RolloutDebugTensors( + rollout_variance_noises=variance_noises.cpu(), + rollout_prev_sample_means=prev_sample_means.cpu(), + rollout_noise_std_devs=noise_std_devs.cpu(), + rollout_model_outputs=model_outputs.cpu(), + ) diff --git a/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py b/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py new file mode 100644 index 000000000000..598d6b9eba47 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Flow-matching rollout step utilities for log-prob computation.""" + +import math +from typing import Any, Union + +import torch + +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_sp_world_size, +) +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + sequence_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req +from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import ( + RolloutSessionData, +) +from sglang.multimodal_gen.runtime.post_training.scheduler_rl_debug_mixin import ( + SchedulerRLDebugMixin, +) + +_LOG_SQRT_2PI = math.log(math.sqrt(2 * math.pi)) + + +class SchedulerRLMixin(SchedulerRLDebugMixin): + + @staticmethod + def _get_rollout_session_data(batch) -> RolloutSessionData: + """Return the RolloutSessionData attached to *batch*, or raise if not prepared.""" + rollout_session_data = getattr(batch, "_rollout_session_data", None) + if rollout_session_data is None: + raise RuntimeError("prepare_rollout() not called before rollout") + return rollout_session_data + + def release_rollout_resources(self, batch) -> None: + """Release rollout-owned resources. Call when denoising ends or before a new rollout.""" + batch._rollout_session_data = None + + def prepare_rollout(self, batch: Req, pipeline_config: Any = None) -> None: + """Enable rollout and set SDE/CPS params. Call once before the denoising loop.""" + if get_sp_world_size() > 1 and pipeline_config is None: + raise RuntimeError( + "SP rollout requires pipeline_config to be passed to prepare_rollout()." + ) + batch._rollout_session_data = RolloutSessionData( + pipeline_config=pipeline_config, + sigma_max=self.sigmas[min(1, len(self.sigmas) - 1)].item(), + latents_shape=( + tuple(batch.latents.shape) if batch.latents is not None else None + ), + ) + + def already_prepared_rollout(self, batch) -> bool: + return getattr(batch, "_rollout_session_data", None) is not None + + def _get_or_create_rollout_noise_buffer( + self, + rollout_session_data: RolloutSessionData, + full_shape: tuple, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Get or create the reusable noise buffer (local or full shape) for rollout.""" + buffer = rollout_session_data.noise_buffer + if ( + buffer is None + or buffer.shape != full_shape + or buffer.dtype != dtype + or buffer.device != device + ): + buffer = torch.empty(full_shape, device=device, dtype=dtype) + rollout_session_data.noise_buffer = buffer + return buffer + + def _rollout_variance_noise( + self, + batch, + model_output: torch.FloatTensor, + generator: Union[torch.Generator, list[torch.Generator]], + ) -> torch.FloatTensor: + """Generate variance noise for rollout. 
If generator is a list, use generator[i] for the i-th batch item."""
+        assert generator is not None, "Generator must be provided"
+
+        rollout_session_data = self._get_rollout_session_data(batch)
+        device = model_output.device
+        dtype = model_output.dtype
+        local_shape = tuple(model_output.shape)
+
+        B = local_shape[0]
+        if isinstance(generator, torch.Generator):
+            assert B == 1, "Generator must be a list if batch size is not 1"
+            generator = [generator]
+        else:
+            assert (
+                len(generator) == B
+            ), "Generator list must have the same length as batch size"
+
+        buffer = self._get_or_create_rollout_noise_buffer(
+            rollout_session_data, rollout_session_data.latents_shape, device, dtype
+        )
+        # Fill the buffer one sample at a time so each batch item consumes its own
+        # generator stream. The randn size must be the per-sample slice shape; using
+        # the full batched latents_shape would mismatch the (1, ...) out= slice when
+        # B > 1.
+        per_sample_shape = (1, *rollout_session_data.latents_shape[1:])
+        for i in range(B):
+            torch.randn(
+                per_sample_shape,
+                out=buffer[i : i + 1],
+                generator=generator[i],
+            )
+
+        sharded_noise, _ = rollout_session_data.pipeline_config.shard_latents_for_sp(
+            batch, buffer
+        )
+        if tuple(sharded_noise.shape) != local_shape:
+            raise ValueError(
+                "Rollout SP noise shape mismatch after shard. "
+                f"Expected local_shape={local_shape}, got {tuple(sharded_noise.shape)}."
+            )
+        return sharded_noise
+
+    def flow_sde_sampling(
+        self,
+        batch,
+        model_output: torch.FloatTensor,
+        sample: torch.FloatTensor,
+        current_sigma: torch.FloatTensor,
+        next_sigma: torch.FloatTensor,
+        generator: Union[torch.Generator, list[torch.Generator]],
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Flow rollout step for log-prob / sampling (see FlowGRPO-style references).
+
+        ``rollout_sde_type`` (from batch SamplingParams):
+
+        1. ``"sde"``: Standard stochastic differential equation transition (Gaussian).
+        2. ``"cps"``: Coupled Particle Sampling.
+        3. ``"ode"``: Deterministic ODE step (no diffusion noise).
+        """
+        rollout_session_data = self._get_rollout_session_data(batch)
+        sde_type = batch.rollout_sde_type
+        noise_level = float(batch.rollout_noise_level)
+        log_prob_no_const = batch.rollout_log_prob_no_const
+        debug_mode = bool(getattr(batch, "rollout_debug_mode", False))
+
+        if not log_prob_no_const and sde_type != "ode":
+            assert (
+                noise_level > 0
+            ), "True log-probability computation requires a non-zero noise level."
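+        # Summary of the math below: for sde_type="sde" the step is a Gaussian
+        # transition x_next ~ N(prev_sample_mean, noise_std_dev**2 * I), with
+        #   std_dev_t     = sqrt(sigma / (1 - sigma)) * noise_level   (sigma==1 guarded)
+        #   noise_std_dev = std_dev_t * sqrt(-dt)                     (dt < 0 as sigma decreases)
+        # so the exact per-element log-prob is
+        #   -(x - mean)**2 / (2 * noise_std_dev**2) - log(noise_std_dev) - log(sqrt(2*pi)),
+        # and rollout_log_prob_no_const keeps only the -(x - mean)**2 term.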
+ + dt = next_sigma - current_sigma + if sde_type == "sde": + variance_noise = self._rollout_variance_noise( + batch, model_output, generator + ) + std_dev_t = ( + torch.sqrt( + current_sigma + / ( + 1 + - torch.where( + torch.isclose(current_sigma, current_sigma.new_tensor(1.0)), + rollout_session_data.sigma_max, + current_sigma, + ) + ) + ) + * noise_level + ) + noise_std_dev = std_dev_t * torch.sqrt(-1 * dt) + prev_sample_mean = ( + sample * (1 + std_dev_t**2 / (2 * current_sigma) * dt) + + model_output + * (1 + std_dev_t**2 * (1 - current_sigma) / (2 * current_sigma)) + * dt + ) + + weighted_variance_noise = variance_noise * noise_std_dev + prev_sample = prev_sample_mean + weighted_variance_noise + log_prob_no_const_val = -(weighted_variance_noise**2) + + elif sde_type == "cps": + variance_noise = self._rollout_variance_noise( + batch, model_output, generator + ) + std_dev_t = next_sigma * math.sin(noise_level * math.pi / 2) + noise_std_dev = std_dev_t + pred_original_sample = sample - current_sigma * model_output + noise_estimate = sample + model_output * (1 - current_sigma) + prev_sample_mean = pred_original_sample * ( + 1 - next_sigma + ) + noise_estimate * torch.sqrt(next_sigma**2 - std_dev_t**2) + + weighted_variance_noise = variance_noise * noise_std_dev + prev_sample = prev_sample_mean + weighted_variance_noise + log_prob_no_const_val = -(weighted_variance_noise**2) + + elif sde_type == "ode": + prev_sample = sample + dt * model_output + prev_sample_mean = prev_sample + variance_noise = torch.zeros_like(model_output) + noise_std_dev = torch.zeros( + (), device=model_output.device, dtype=model_output.dtype + ) + log_prob_no_const_val = torch.zeros_like(model_output) + assert ( + log_prob_no_const + ), "p_ode is always 0, true log_prob is meaningless, set rollout_log_prob_no_const to True to enable log_prob computation" + + else: + raise ValueError(f"Unsupported sde_type: {sde_type}") + + reduce_dims = list(range(1, len(log_prob_no_const_val.shape))) + local_elem_count = log_prob_no_const_val.new_full( + (log_prob_no_const_val.shape[0],), + float(math.prod(log_prob_no_const_val.shape[1:])), + ) + + if log_prob_no_const: + log_prob_local_sum = log_prob_no_const_val.sum(dim=reduce_dims) + else: + log_prob_local_sum = ( + log_prob_no_const_val / (2 * (noise_std_dev**2)) + - torch.log(noise_std_dev) + - _LOG_SQRT_2PI + ).sum(dim=list(range(1, len(log_prob_no_const_val.shape)))) + + if debug_mode: + self.append_local_rollout_debug_tensors( + batch, + variance_noise=variance_noise, + prev_sample_mean=prev_sample_mean, + noise_std_dev=noise_std_dev, + model_output=model_output, + ) + + return prev_sample, log_prob_local_sum, local_elem_count + + def append_local_rollout_log_probs( + self, batch, log_prob_sum: torch.Tensor, log_prob_count: torch.Tensor + ) -> None: + rollout_session_data = self._get_rollout_session_data(batch) + rollout_session_data.local_log_prob_sum.append(log_prob_sum) + rollout_session_data.local_log_prob_count.append(log_prob_count) + + def consume_local_rollout_log_probs( + self, batch + ) -> tuple[torch.Tensor, torch.Tensor]: + rollout_session_data = self._get_rollout_session_data(batch) + values_sum = torch.stack(rollout_session_data.local_log_prob_sum, dim=-1) + values_count = torch.stack(rollout_session_data.local_log_prob_count, dim=-1) + rollout_session_data.local_log_prob_sum = [] + rollout_session_data.local_log_prob_count = [] + return values_sum, values_count + + def collect_rollout_log_probs(self, batch: Req) -> torch.Tensor | None: + """Consume local 
rollout log probs and merge for all SP ranks."""
+
+        trajectory_log_prob_sum, trajectory_log_prob_count = (
+            self.consume_local_rollout_log_probs(batch)
+        )
+        if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False):
+            packed = torch.stack(
+                [trajectory_log_prob_sum, trajectory_log_prob_count], dim=0
+            ).to(get_local_torch_device())
+            # all_reduce may run out-of-place depending on the group backend, so
+            # rebind the returned tensor instead of relying on in-place mutation.
+            packed = sequence_model_parallel_all_reduce(packed)
+            trajectory_log_prob_sum = packed[0]
+            trajectory_log_prob_count = packed[1]
+
+        # Per-step mean log-prob per sample, shape [B, T].
+        rollout_log_probs_tensor = trajectory_log_prob_sum / trajectory_log_prob_count
+        return rollout_log_probs_tensor.cpu()
diff --git a/python/sglang/multimodal_gen/test/run_suite.py b/python/sglang/multimodal_gen/test/run_suite.py
index f7182b9bb7f9..3d65416a094d 100644
--- a/python/sglang/multimodal_gen/test/run_suite.py
+++ b/python/sglang/multimodal_gen/test/run_suite.py
@@ -35,6 +35,7 @@
         "../unit/test_storage.py",
         "../unit/test_lora_format_adapter.py",
         "../unit/test_server_args.py",
+        "../unit/test_scheduler_rollout_unit.py",
         # add new unit tests here
     ],
     "1-gpu": [
diff --git a/python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py b/python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py
new file mode 100644
index 000000000000..33cebfd76aa0
--- /dev/null
+++ b/python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py
@@ -0,0 +1,272 @@
+import math
+import types
+import unittest
+
+import torch
+
+import sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin as rl_mixin_module
+from sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin import (
+    SchedulerRLMixin,
+)
+
+
+class _DummyScheduler(SchedulerRLMixin):
+    def __init__(self):
+        self.sigmas = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2, 0.0], dtype=torch.float32)
+
+
+class TestSchedulerRolloutOdeUnit(unittest.TestCase):
+    def setUp(self):
+        self._orig_get_sp_world_size = rl_mixin_module.get_sp_world_size
+        rl_mixin_module.get_sp_world_size = lambda: 1
+
+    def tearDown(self):
+        rl_mixin_module.get_sp_world_size = self._orig_get_sp_world_size
+
+    def _build_batch(self, *, debug_mode: bool) -> types.SimpleNamespace:
+        return types.SimpleNamespace(
+            rollout_log_prob_no_const=True,
+            rollout_noise_level=0.5,
+            rollout_sde_type="ode",
+            rollout_debug_mode=debug_mode,
+            latents=torch.zeros(2, 4, 8, 8, dtype=torch.float32),
+            _rollout_session_data=None,
+        )
+
+    def test_ode_step_does_not_call_variance_noise_sampler(self):
+        scheduler = _DummyScheduler()
+        batch = self._build_batch(debug_mode=False)
+        scheduler.prepare_rollout(batch)
+
+        def _raise_if_called(*args, **kwargs):
+            raise AssertionError("ODE path should not call _rollout_variance_noise")
+
+        scheduler._rollout_variance_noise = _raise_if_called  # type: ignore[method-assign]
+
+        sample = torch.randn(2, 4, 8, 8, dtype=torch.float32)
+        model_output = torch.randn_like(sample)
+        current_sigma = torch.tensor(0.6, dtype=torch.float32)
+        next_sigma = torch.tensor(0.4, dtype=torch.float32)
+
+        prev_sample, log_prob_local_sum, local_elem_count = scheduler.flow_sde_sampling(
+            batch,
+            model_output=model_output,
+            sample=sample,
+            current_sigma=current_sigma,
+            next_sigma=next_sigma,
+            generator=torch.Generator(device=sample.device).manual_seed(1),
+        )
+
+        expected_prev = sample + (next_sigma - current_sigma) * model_output
+        self.assertTrue(torch.allclose(prev_sample, expected_prev, atol=1e-6, rtol=0.0))
+        self.assertTrue(
+            torch.allclose(log_prob_local_sum, torch.zeros_like(log_prob_local_sum))
+        )
+        self.assertEqual(tuple(log_prob_local_sum.shape),
(sample.shape[0],)) + self.assertEqual(tuple(local_elem_count.shape), (sample.shape[0],)) + self.assertTrue(torch.all(local_elem_count == float(sample[0].numel()))) + + def test_ode_debug_tensors_have_shape_safe_noise_std(self): + scheduler = _DummyScheduler() + batch = self._build_batch(debug_mode=True) + scheduler.prepare_rollout(batch) + + sample = torch.randn(2, 4, 8, 8, dtype=torch.float32) + model_output = torch.randn_like(sample) + current_sigma = torch.tensor(0.6, dtype=torch.float32) + next_sigma = torch.tensor(0.4, dtype=torch.float32) + + scheduler.flow_sde_sampling( + batch, + model_output=model_output, + sample=sample, + current_sigma=current_sigma, + next_sigma=next_sigma, + generator=torch.Generator(device=sample.device).manual_seed(2), + ) + + ( + variance_noises, + prev_sample_means, + noise_std_devs, + model_outputs, + ) = scheduler.consume_local_rollout_debug_tensors(batch) + + # [B, T, ...] with one step in this test. + self.assertEqual(tuple(variance_noises.shape), (2, 1, 4, 8, 8)) + self.assertEqual(tuple(prev_sample_means.shape), (2, 1, 4, 8, 8)) + self.assertEqual(tuple(model_outputs.shape), (2, 1, 4, 8, 8)) + self.assertEqual(tuple(noise_std_devs.shape), (2, 1, 1)) + self.assertTrue( + torch.allclose(noise_std_devs, torch.zeros_like(noise_std_devs)) + ) + self.assertTrue( + torch.allclose(variance_noises, torch.zeros_like(variance_noises)) + ) + + +def _flowgrpo_sde_step_with_logprob( + *, + model_output: torch.Tensor, + sample: torch.Tensor, + variance_noise: torch.Tensor, + sigma: torch.Tensor, + sigma_prev: torch.Tensor, + sigma_max: float, + noise_level: float, + sde_type: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Verbatim from FlowGRPO sd3_sde_with_logprob.py ``sde_step_with_logprob``. + + Returns (prev_sample, log_prob, prev_sample_mean, noise_std_dev). + ``sigma`` / ``sigma_prev`` follow FlowGRPO convention (current / next). + """ + model_output = model_output.float() + sample = sample.float() + + dt = sigma_prev - sigma + + if sde_type == "sde": + std_dev_t = ( + torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) + * noise_level + ) + prev_sample_mean = ( + sample * (1 + std_dev_t**2 / (2 * sigma) * dt) + + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt + ) + noise_std_dev = std_dev_t * torch.sqrt(-1 * dt) + prev_sample = prev_sample_mean + noise_std_dev * variance_noise + + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) + / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2)) + - torch.log(std_dev_t * torch.sqrt(-1 * dt)) + - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) + ) + + elif sde_type == "cps": + std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) + noise_std_dev = std_dev_t + pred_original_sample = sample - sigma * model_output + noise_estimate = sample + model_output * (1 - sigma) + prev_sample_mean = pred_original_sample * ( + 1 - sigma_prev + ) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) + + else: + raise ValueError(f"Unsupported sde_type: {sde_type}") + + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + return prev_sample, log_prob, prev_sample_mean, noise_std_dev + + +# FlowGRPO convention: SDE uses full Gaussian log-prob, CPS uses no_const. 
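+# Note: the mixin returns per-step (log_prob_sum, elem_count) rather than a mean,
+# so the alignment test below compares log_prob_sum / elem_count against the
+# per-element mean log-prob returned by the FlowGRPO reference step.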
+_FLOWGRPO_LOG_PROB_NO_CONST = {"sde": False, "cps": True} + + +class TestSchedulerFlowGRPOStepAlignmentUnit(unittest.TestCase): + def setUp(self): + self._orig_get_sp_world_size = rl_mixin_module.get_sp_world_size + rl_mixin_module.get_sp_world_size = lambda: 1 + + def tearDown(self): + rl_mixin_module.get_sp_world_size = self._orig_get_sp_world_size + + def _build_batch( + self, *, sde_type: str, shape: tuple[int, ...] + ) -> types.SimpleNamespace: + return types.SimpleNamespace( + rollout_log_prob_no_const=_FLOWGRPO_LOG_PROB_NO_CONST[sde_type], + rollout_noise_level=0.5, + rollout_sde_type=sde_type, + rollout_debug_mode=True, + latents=torch.empty(shape, dtype=torch.float32), + _rollout_session_data=None, + ) + + def test_single_step_matches_flowgrpo_reference(self): + """Verify prev_sample, prev_sample_mean, noise_std_dev, and log_prob + all match FlowGRPO's ``sde_step_with_logprob`` for SDE and CPS.""" + scheduler = _DummyScheduler() + current_sigma = torch.tensor(0.5, dtype=torch.float32) + next_sigma = torch.tensor(0.3, dtype=torch.float32) + shape = (1, 16, 1, 32, 32) + atol = 1e-6 + pipeline_config = types.SimpleNamespace( + shard_latents_for_sp=lambda _batch, latents: (latents, False) + ) + + for sde_type in ("sde", "cps"): + for seed in (0, 1, 2, 3): + batch = self._build_batch(sde_type=sde_type, shape=shape) + scheduler.release_rollout_resources(batch) + scheduler.prepare_rollout(batch=batch, pipeline_config=pipeline_config) + + g = torch.Generator(device="cpu").manual_seed(seed) + model_output = torch.randn(shape, generator=g, dtype=torch.float32) + sample = torch.randn(shape, generator=g, dtype=torch.float32) + variance_noise = torch.randn(shape, generator=g, dtype=torch.float32) + scheduler._rollout_variance_noise = ( # type: ignore[method-assign] + lambda _batch, *_args, **_kwargs: variance_noise + ) + + prev_sgl, log_prob_sum, elem_count = scheduler.flow_sde_sampling( + batch, + model_output=model_output, + sample=sample, + current_sigma=current_sigma, + next_sigma=next_sigma, + generator=g, + ) + ( + _variance_noises, + prev_sample_means, + noise_std_devs, + _model_outputs, + ) = scheduler.consume_local_rollout_debug_tensors(batch) + + prev_ref, log_prob_ref, prev_mean_ref, noise_std_ref = ( + _flowgrpo_sde_step_with_logprob( + model_output=model_output, + sample=sample, + variance_noise=variance_noise, + sigma=current_sigma, + sigma_prev=next_sigma, + sigma_max=scheduler.sigmas[1].item(), + noise_level=0.5, + sde_type=sde_type, + ) + ) + + log_prob_mean = log_prob_sum / elem_count + + errs = { + "prev_sample": float((prev_sgl - prev_ref).abs().max().item()), + "prev_sample_mean": float( + (prev_sample_means[:, 0] - prev_mean_ref).abs().max().item() + ), + "noise_std": float( + (noise_std_devs[:, 0, 0] - noise_std_ref.reshape(-1)) + .abs() + .max() + .item() + ), + "log_prob": float( + (log_prob_mean - log_prob_ref).abs().max().item() + ), + } + + for name, err in errs.items(): + self.assertLessEqual( + err, + atol, + msg=f"{sde_type} seed={seed} {name} max_abs={err:.9f}", + ) + + +if __name__ == "__main__": + unittest.main()
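End-to-end usage sketch. The parameter and attribute names are the ones added in this diff; the surrounding invocation is schematic, not a verbatim API:

    from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams

    params = SamplingParams(
        rollout=True,                     # collect per-step log-probs
        rollout_sde_type="sde",           # one of "sde" | "cps" | "ode"
        rollout_noise_level=0.7,
        rollout_log_prob_no_const=False,  # keep the full Gaussian log-density
    )
    # After generation, each GenerationResult carries:
    #   result.rollout_trajectory_data.rollout_log_probs      # [B, T], on CPU
    #   result.rollout_trajectory_data.rollout_debug_tensors  # only with rollout_debug_mode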