-
Notifications
You must be signed in to change notification settings - Fork 5k
[diffusion] feat: add rollout log_prob with flow-matching SDE/CPS support
#18806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| """Flow-matching rollout step utilities for log-prob computation.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we adapt from other open-source diffusion workflows, we shall add acknowledgment here. |
||
|
|
||
| import math | ||
| from typing import Any, Optional, Union | ||
|
|
||
| import torch | ||
| from diffusers.utils.torch_utils import randn_tensor | ||
|
|
||
|
|
||
def sde_step_with_logprob(
    self: Any,
    model_output: torch.FloatTensor,
    sample: torch.FloatTensor,
    step_index: int,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
    sde_type: str = "sde",
):
    """Run one rollout step and compute a per-sample log-probability.

    NOTE(review): if this was adapted from another open-source diffusion
    workflow, add an acknowledgment here — TODO confirm provenance.

    Args:
        self: A flow-matching scheduler exposing a ``sigmas`` tensor
            (e.g. a ``FlowMatchEulerDiscreteScheduler``-like object).
        model_output: Model prediction (velocity) for the current step.
        sample: Current latent sample ``x_t``.
        step_index: Index into ``self.sigmas`` for the current step.
        noise_level: Stochasticity strength of the transition kernel.
        prev_sample: If given, score this sample under the transition
            instead of drawing a fresh one.
        generator: Optional RNG used only when ``prev_sample`` is drawn.
        sde_type: ``"sde"`` uses the Gaussian SDE transition objective;
            ``"cps"`` uses the simplified CPS objective.

    Returns:
        Tuple ``(prev_sample, log_prob)`` where ``prev_sample`` is cast
        back to the input dtype and ``log_prob`` has shape ``(batch,)``
        (mean of the element-wise Gaussian log-density over all
        non-batch dimensions).

    Raises:
        ValueError: If ``sde_type`` is neither ``"sde"`` nor ``"cps"``.
    """
    # Validate configuration up front so bad requests fail before any
    # tensor work is done (reviewer request: fail early).
    if sde_type not in ("sde", "cps"):
        raise ValueError(f"Unsupported sde_type: {sde_type}")

    # All math runs in fp32 for numerical stability; cast back on return.
    sample_dtype = sample.dtype
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    # Per-sample sigma for this step and the next one; the last index is
    # clamped so the final step reuses the terminal sigma.
    step_indices = torch.full(
        (sample.shape[0],),
        int(step_index),
        device=self.sigmas.device,
        dtype=torch.long,
    )
    prev_step_indices = (step_indices + 1).clamp_max(len(self.sigmas) - 1)
    sigma = self.sigmas[step_indices].to(device=sample.device, dtype=sample.dtype)
    sigma_prev = self.sigmas[prev_step_indices].to(
        device=sample.device, dtype=sample.dtype
    )
    # Reshape to (batch, 1, 1, ...) so sigmas broadcast over sample dims.
    sigma = sigma.view(-1, *([1] * (sample.ndim - 1)))
    sigma_prev = sigma_prev.view(-1, *([1] * (sample.ndim - 1)))
    # Largest non-terminal sigma; substituted where sigma == 1 so the
    # 1 / (1 - sigma) factor below stays finite.
    sigma_max = self.sigmas[min(1, len(self.sigmas) - 1)].to(
        device=sample.device, dtype=sample.dtype
    )
    dt = sigma_prev - sigma  # negative: sigmas decrease along the rollout

    # Shared Gaussian log-density helpers (both branches use the same
    # closed form, so compute it in one place).
    half_log_two_pi = 0.5 * math.log(2.0 * math.pi)

    def _gaussian_log_prob(
        x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        # Element-wise log N(x; mean, std^2).
        return -((x - mean) ** 2) / (2 * std**2) - torch.log(std) - half_log_two_pi

    def _draw(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # Sample from the transition kernel N(mean, std^2).
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        return mean + std * variance_noise

    if sde_type == "sde":
        denom_sigma = 1 - torch.where(
            torch.isclose(sigma, sigma.new_tensor(1.0)), sigma_max, sigma
        )
        std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
        prev_sample_mean = (
            sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
            + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
        )
        # dt < 0, so the diffusion term scales with sqrt(-dt).
        sqrt_neg_dt = torch.sqrt((-dt).clamp_min(1e-12))
        if prev_sample is None:
            prev_sample = _draw(prev_sample_mean, std_dev_t * sqrt_neg_dt)
        std = (std_dev_t * sqrt_neg_dt).clamp_min(1e-12)
    else:  # sde_type == "cps" (validated above)
        std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
        pred_original_sample = sample - sigma * model_output
        noise_estimate = sample + model_output * (1 - sigma)
        sigma_delta = (sigma_prev**2 - std_dev_t**2).clamp_min(0.0)
        prev_sample_mean = pred_original_sample * (
            1 - sigma_prev
        ) + noise_estimate * torch.sqrt(sigma_delta)
        if prev_sample is None:
            prev_sample = _draw(prev_sample_mean, std_dev_t)
        # CPS transition is Gaussian with std_dev_t, so the same closed
        # form yields a valid log-probability.
        std = std_dev_t.clamp_min(1e-12)

    log_prob = _gaussian_log_prob(prev_sample.detach(), prev_sample_mean, std)
    # Reduce to one scalar per batch element.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample.to(sample_dtype), log_prob
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,9 @@ | |
| TransformerLoader, | ||
| ) | ||
| from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context | ||
| from sglang.multimodal_gen.runtime.pipelines.patches.flow_matching_with_logprob import ( | ||
| sde_step_with_logprob, | ||
| ) | ||
| from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req | ||
| from sglang.multimodal_gen.runtime.pipelines_core.stages.base import ( | ||
| PipelineStage, | ||
|
|
@@ -695,6 +698,7 @@ def _post_denoising_loop( | |
| latents: torch.Tensor, | ||
| trajectory_latents: list, | ||
| trajectory_timesteps: list, | ||
| trajectory_log_probs: list, | ||
| server_args: ServerArgs, | ||
| is_warmup: bool = False, | ||
| ): | ||
|
|
@@ -705,6 +709,10 @@ def _post_denoising_loop( | |
| else: | ||
| trajectory_tensor = None | ||
| trajectory_timesteps_tensor = None | ||
| if trajectory_log_probs: | ||
| trajectory_log_probs_tensor = torch.stack(trajectory_log_probs, dim=1) | ||
| else: | ||
| trajectory_log_probs_tensor = None | ||
|
|
||
| # Gather results if using sequence parallelism | ||
| latents, trajectory_tensor = self._postprocess_sp_latents( | ||
|
|
@@ -731,6 +739,8 @@ def _post_denoising_loop( | |
| if trajectory_tensor is not None and trajectory_timesteps_tensor is not None: | ||
| batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu() | ||
| batch.trajectory_latents = trajectory_tensor.cpu() | ||
| if trajectory_log_probs_tensor is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure whether this works in SP. We can leave it to the future.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like SP isn’t supported. Could we add a TODO or FIXME? |
||
| batch.trajectory_log_probs = trajectory_log_probs_tensor.cpu() | ||
|
|
||
| # Update batch with final latents | ||
| batch.latents = self.server_args.pipeline_config.post_denoising_loop( | ||
|
|
@@ -998,6 +1008,11 @@ def forward( | |
| # Initialize lists for ODE trajectory | ||
| trajectory_timesteps: list[torch.Tensor] = [] | ||
| trajectory_latents: list[torch.Tensor] = [] | ||
| trajectory_log_probs: list[torch.Tensor] = [] | ||
| rollout_enabled = bool(batch.rollout) | ||
| rollout_sde_type = batch.rollout_sde_type | ||
|
|
||
| rollout_noise_level = batch.rollout_noise_level | ||
|
|
||
| # Run denoising loop | ||
| denoising_start_time = time.time() | ||
|
|
@@ -1006,6 +1021,13 @@ def forward( | |
| is_warmup = batch.is_warmup | ||
| self.scheduler.set_begin_index(0) | ||
| timesteps_cpu = timesteps.cpu() | ||
| rollout_step_indices: list[int] = [] | ||
| if rollout_enabled: | ||
| scheduler_timesteps = self.scheduler.timesteps | ||
| rollout_step_indices = [ | ||
| self.scheduler.index_for_timestep(t.to(scheduler_timesteps.device)) | ||
| for t in timesteps_cpu | ||
| ] | ||
| num_timesteps = timesteps_cpu.shape[0] | ||
| with torch.autocast( | ||
| device_type=current_platform.device_type, | ||
|
|
@@ -1085,13 +1107,25 @@ def forward( | |
| batch.noise_pred = noise_pred | ||
|
|
||
| # Compute the previous noisy sample | ||
| latents = self.scheduler.step( | ||
| model_output=noise_pred, | ||
| timestep=t_device, | ||
| sample=latents, | ||
| **extra_step_kwargs, | ||
| return_dict=False, | ||
| )[0] | ||
| if rollout_enabled: | ||
| latents, step_log_prob = sde_step_with_logprob( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rollout path bypasses the scheduler's own step. That does not seem equivalent for multi-step schedulers. Is this by design? |
||
| self.scheduler, | ||
| model_output=noise_pred, | ||
| sample=latents, | ||
| step_index=rollout_step_indices[i], | ||
| generator=batch.generator, | ||
| sde_type=rollout_sde_type, | ||
| noise_level=rollout_noise_level, | ||
| ) | ||
| trajectory_log_probs.append(step_log_prob) | ||
| else: | ||
| latents = self.scheduler.step( | ||
| model_output=noise_pred, | ||
| timestep=t_device, | ||
| sample=latents, | ||
| **extra_step_kwargs, | ||
| return_dict=False, | ||
| )[0] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a little bit unclear to me that
Indeed, I am not so sure about the process of SDE and CPS. Shall ask for help on the design from BBuf, mick, and Yuhao. |
||
|
|
||
| latents = self.post_forward_for_ti2v_task( | ||
| batch, server_args, reserved_frames_mask, latents, z | ||
|
|
@@ -1126,6 +1160,7 @@ def forward( | |
| latents=latents, | ||
| trajectory_latents=trajectory_latents, | ||
| trajectory_timesteps=trajectory_timesteps, | ||
| trajectory_log_probs=trajectory_log_probs, | ||
| server_args=server_args, | ||
| is_warmup=is_warmup, | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we validate
`rollout_sde_type` in `_validate` so the request fails early?