Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ class SamplingParams:
# Misc
save_output: bool = True
return_frames: bool = False
rollout: bool = False
rollout_sde_type: str = "sde"
rollout_noise_level: float = 0.7
rollout_log_prob_no_const: bool = False # exclude constants in rollout logprob
rollout_debug_mode: bool = (
False # return rollout debug tensors (intermediate states)
)
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
# if True, disallow user params to override subclass-defined protected fields
Expand Down Expand Up @@ -319,6 +326,16 @@ def _finite_non_negative_float(
_finite_non_negative_float(
"guidance_rescale", self.guidance_rescale, allow_none=False
)
_finite_non_negative_float(
"rollout_noise_level", self.rollout_noise_level, allow_none=False
)

_VALID_ROLLOUT_SDE_TYPES = ("sde", "cps", "ode")
if self.rollout_sde_type not in _VALID_ROLLOUT_SDE_TYPES:
raise ValueError(
f"rollout_sde_type must be one of {_VALID_ROLLOUT_SDE_TYPES}, "
f"got {self.rollout_sde_type!r}"
)

if self.cfg_normalization is None:
self.cfg_normalization = 0.0
Expand Down Expand Up @@ -803,6 +820,32 @@ def add_argument(*name_or_flags, **kwargs):
action="store_true",
help="Whether to return the trajectory",
)
add_argument(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a dedicated parser for the RL options, and register that parser's arguments here?

"--rollout",
action="store_true",
help="Enable rollout mode and return per-step log_prob trajectory",
)
add_argument(
"--rollout-sde-type",
type=str,
choices=["sde", "cps", "ode"],
help="Rollout step objective type used in log-prob computation.",
)
add_argument(
"--rollout-noise-level",
type=float,
help="Noise level used by rollout SDE/CPS step objective.",
)
add_argument(
"--rollout-log-prob-no-const",
action=StoreBoolean,
help="If true, return rollout log-prob without constant terms.",
)
add_argument(
"--rollout-debug-mode",
action=StoreBoolean,
help="If true, return rollout debug tensors (variance noise, mean, std, model output).",
)
add_argument(
"--return-trajectory-decoded",
action="store_true",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def sequence_model_parallel_all_gather(
return get_sp_group().all_gather(input_, dim)


def sequence_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """Perform an all-reduce of ``input_`` over the sequence-parallel group.

    Returns the tensor produced by the group's ``all_reduce`` call.
    """
    sp_group = get_sp_group()
    return sp_group.all_reduce(input_)


def cfg_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False
) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def generate(
),
trajectory_latents=output_batch.trajectory_latents,
trajectory_timesteps=output_batch.trajectory_timesteps,
rollout_trajectory_data=output_batch.rollout_trajectory_data,
trajectory_decoded=output_batch.trajectory_decoded,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class GenerationResult:
metrics: dict = field(default_factory=dict)
trajectory_latents: Any = None
trajectory_timesteps: Any = None
rollout_trajectory_data: Any = None
trajectory_decoded: Any = None
prompt_index: int = 0
output_file_path: str | None = None
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/multimodal_gen/runtime/managers/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
metrics=result.metrics,
trajectory_timesteps=getattr(result, "trajectory_timesteps", None),
trajectory_latents=getattr(result, "trajectory_latents", None),
rollout_trajectory_data=getattr(
result, "rollout_trajectory_data", None
),
noise_pred=getattr(result, "noise_pred", None),
trajectory_decoded=getattr(result, "trajectory_decoded", None),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,8 @@ def forward(
x = list(unified.unbind(dim=0))
x = self.unpatchify(x, x_size, patch_size, f_patch_size)

return -x[0]
# Keep batch dim so output shape matches input (e.g. rollout/scheduler expect same ndim).
return -torch.stack(x)


EntryClass = ZImageTransformer2DModel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from diffusers.utils import BaseOutput

from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
from sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin import (
SchedulerRLMixin,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand All @@ -51,7 +54,9 @@ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
class FlowMatchEulerDiscreteScheduler(
SchedulerMixin, ConfigMixin, BaseScheduler, SchedulerRLMixin
):
"""
Euler scheduler.
Expand Down Expand Up @@ -447,6 +452,7 @@ def step(
s_noise: float = 1.0,
generator: torch.Generator | None = None,
per_token_timesteps: torch.Tensor | None = None,
batch=None,
return_dict: bool = True,
) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]:
"""
Expand Down Expand Up @@ -516,12 +522,22 @@ def step(
next_sigma = sigma_next
dt = sigma_next - sigma

if self.config.stochastic_sampling:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
if batch is not None and self.already_prepared_rollout(batch):
prev_sample, log_prob_local_sum, log_prob_local_count = (
self.flow_sde_sampling(
batch, model_output, sample, current_sigma, next_sigma, generator
)
)
self.append_local_rollout_log_probs(
batch, log_prob_local_sum, log_prob_local_count
)
else:
prev_sample = sample + dt * model_output
if self.config.stochastic_sampling:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
else:
prev_sample = sample + dt * model_output

# upon completion increase step index by one
assert self._step_index is not None, "_step_index should not be None"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import torch

from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import (
RolloutTrajectoryData,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import (
_sanitize_for_logging,
Expand Down Expand Up @@ -131,8 +134,9 @@ class Req:
# Component modules (populated by the pipeline)
modules: dict[str, Any] = field(default_factory=dict)

trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_timesteps: torch.Tensor | None = None
trajectory_latents: torch.Tensor | None = None
rollout_trajectory_data: RolloutTrajectoryData | None = None
trajectory_audio_latents: torch.Tensor | None = None

# Extra parameters that might be needed by specific pipeline implementations
Expand Down Expand Up @@ -333,8 +337,9 @@ class OutputBatch:
output: torch.Tensor | None = None
audio: torch.Tensor | None = None
audio_sample_rate: int | None = None
trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_timesteps: torch.Tensor | None = None
trajectory_latents: torch.Tensor | None = None
rollout_trajectory_data: RolloutTrajectoryData | None = None
trajectory_decoded: list[torch.Tensor] | None = None
error: str | None = None
output_file_paths: list[str] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def forward(
output=frames,
trajectory_timesteps=batch.trajectory_timesteps,
trajectory_latents=batch.trajectory_latents,
rollout_trajectory_data=batch.rollout_trajectory_data,
trajectory_decoded=trajectory_decoded,
metrics=batch.metrics,
noise_pred=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.post_training.rl_dataclasses import (
RolloutTrajectoryData,
)
from sglang.multimodal_gen.runtime.post_training.scheduler_rl_mixin import (
SchedulerRLMixin,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
Expand Down Expand Up @@ -129,6 +135,43 @@ def __init__(
self._cached_num_steps = None
self._is_warmed_up = False

def _maybe_prepare_rollout(self, batch: Req):
    """Prepare denoising loop for rollout.

    Raises:
        ValueError: if rollout is requested but the scheduler does not
            implement ``SchedulerRLMixin``.
    """
    supports_rollout = isinstance(self.scheduler, SchedulerRLMixin)
    if not supports_rollout:
        if batch.rollout:
            raise ValueError(
                f"Scheduler {type(self.scheduler)} does not support rollout"
            )
        return

    # Drop any stale per-batch rollout state before (re)preparing.
    self.scheduler.release_rollout_resources(batch)
    if not batch.rollout:
        return
    self.scheduler.prepare_rollout(
        batch=batch,
        pipeline_config=self.server_args.pipeline_config,
    )

def _maybe_collect_rollout_log_probs(self, batch: Req):
    """Get rollout log probs and store into batch for reward calculation."""
    # Same guard as _maybe_prepare_rollout: rollout requires an RL-capable scheduler.
    if not isinstance(self.scheduler, SchedulerRLMixin):
        if batch.rollout:
            raise ValueError(
                f"Scheduler {type(self.scheduler)} does not support rollout"
            )
        return

    if batch.rollout:
        # Lazily create the container the first time results are collected.
        if batch.rollout_trajectory_data is None:
            batch.rollout_trajectory_data = RolloutTrajectoryData()
        batch.rollout_trajectory_data.rollout_log_probs = (
            self.scheduler.collect_rollout_log_probs(batch)
        )
        # Debug tensors are only gathered when explicitly requested.
        if getattr(batch, "rollout_debug_mode", False):
            batch.rollout_trajectory_data.rollout_debug_tensors = (
                self.scheduler.collect_rollout_debug_tensors(batch)
            )
        # NOTE(review): placed inside the rollout branch per the diff's line
        # order — confirm the release is not meant to run unconditionally as
        # it does in _maybe_prepare_rollout.
        self.scheduler.release_rollout_resources(batch)

def _maybe_enable_torch_compile(self, module: object) -> None:
"""
Compile a module with torch.compile, and enable inductor overlap tweak if available.
Expand Down Expand Up @@ -557,10 +600,12 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
else:
self._maybe_enable_cache_dit(cache_dit_num_inference_steps, batch)

self._maybe_prepare_rollout(batch)

# Prepare extra step kwargs for scheduler
extra_step_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.step,
{"generator": batch.generator, "eta": batch.eta},
{"generator": batch.generator, "eta": batch.eta, "batch": batch},
)

# Setup precision and autocast settings
Expand Down Expand Up @@ -718,6 +763,9 @@ def _post_denoising_loop(
trajectory_tensor = None
trajectory_timesteps_tensor = None

# Gather log probs for rollout
self._maybe_collect_rollout_log_probs(batch)

# Gather results if using sequence parallelism
latents, trajectory_tensor = self._postprocess_sp_latents(
batch, latents, trajectory_tensor
Expand Down Expand Up @@ -1075,7 +1123,6 @@ def forward(
guidance=guidance,
latents=latents,
)

# Save noise_pred to batch for external access (e.g., ComfyUI)
if server_args.comfyui_mode:
batch.noise_pred = noise_pred
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
"""RL-specific dataclasses used by post-training and rollout paths."""

from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class RolloutSessionData:
"""Per-batch rollout state created by prepare_rollout(), lives on the batch object.

Cleared by setting ``batch._rollout_session_data = None``.
"""

pipeline_config: Any = None
sigma_max: float = 0.0
latents_shape: tuple | None = None
noise_buffer: torch.Tensor | None = None

local_log_prob_sum: list[torch.Tensor] = field(default_factory=list)
local_log_prob_count: list[torch.Tensor] = field(default_factory=list)

local_variance_noises: list[torch.Tensor] = field(default_factory=list)
local_prev_sample_means: list[torch.Tensor] = field(default_factory=list)
local_noise_std_devs: list[torch.Tensor] = field(default_factory=list)
local_model_outputs: list[torch.Tensor] = field(default_factory=list)


@dataclass
class RolloutDebugTensors:
"""Container for rollout debug tensors collected during denoising."""

rollout_variance_noises: torch.Tensor | None = None
rollout_prev_sample_means: torch.Tensor | None = None
rollout_noise_std_devs: torch.Tensor | None = None
rollout_model_outputs: torch.Tensor | None = None


@dataclass
class RolloutTrajectoryData:
"""Container for rollout-specific trajectory outputs."""

rollout_log_probs: torch.Tensor | None = None
rollout_debug_tensors: RolloutDebugTensors | None = None
Loading
Loading