Skip to content

Commit 4abb27b

Browse files
committed
[diffusion] fix: precompute rollout step index and correct cps log_prob (#18806)
1 parent dcaadb8 commit 4abb27b

File tree

2 files changed: 19 additions and 28 deletions.

python/sglang/multimodal_gen/runtime/pipelines/patches/flow_matching_with_logprob.py

Lines changed: 11 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -8,31 +8,11 @@
88
from diffusers.utils.torch_utils import randn_tensor
99

1010

11-
def _as_timestep_tensor(
12-
timestep: Union[float, torch.Tensor], batch_size: int, device: torch.device
13-
) -> torch.Tensor:
14-
"""Normalize timestep input to a 1D tensor on the target device."""
15-
if torch.is_tensor(timestep):
16-
ts = timestep.to(device=device)
17-
else:
18-
ts = torch.tensor([timestep], device=device)
19-
20-
if ts.ndim == 0:
21-
ts = ts.view(1)
22-
else:
23-
ts = ts.view(-1)
24-
25-
# Broadcast scalar timestep to match batch size.
26-
if ts.numel() == 1 and batch_size > 1:
27-
ts = ts.repeat(batch_size)
28-
return ts
29-
30-
3111
def sde_step_with_logprob(
3212
self: Any,
3313
model_output: torch.FloatTensor,
34-
timestep: Union[float, torch.FloatTensor],
3514
sample: torch.FloatTensor,
15+
step_index: int,
3616
noise_level: float = 0.7,
3717
prev_sample: Optional[torch.FloatTensor] = None,
3818
generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
@@ -49,10 +29,9 @@ def sde_step_with_logprob(
4929
if prev_sample is not None:
5030
prev_sample = prev_sample.float()
5131

52-
batch_size = sample.shape[0]
53-
timestep_tensor = _as_timestep_tensor(timestep, batch_size, sample.device)
54-
step_indices = torch.tensor(
55-
[self.index_for_timestep(t.to(self.timesteps.device)) for t in timestep_tensor],
32+
step_indices = torch.full(
33+
(sample.shape[0],),
34+
int(step_index),
5635
device=self.sigmas.device,
5736
dtype=torch.long,
5837
)
@@ -112,8 +91,13 @@ def sde_step_with_logprob(
11291
)
11392
prev_sample = prev_sample_mean + std_dev_t * variance_noise
11493

115-
# Keep the same simplified cps objective used in the original patch.
116-
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
94+
# CPS transition is Gaussian with std_dev_t, so compute a valid log-probability.
95+
std = std_dev_t.clamp_min(1e-12)
96+
log_prob = (
97+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std**2))
98+
- torch.log(std)
99+
- torch.log(torch.sqrt(torch.as_tensor(2 * math.pi, device=std.device)))
100+
)
117101
else:
118102
raise ValueError(f"Unsupported sde_type: {sde_type}")
119103

python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1021,6 +1021,13 @@ def forward(
10211021
is_warmup = batch.is_warmup
10221022
self.scheduler.set_begin_index(0)
10231023
timesteps_cpu = timesteps.cpu()
1024+
rollout_step_indices: list[int] = []
1025+
if rollout_enabled:
1026+
scheduler_timesteps = self.scheduler.timesteps
1027+
rollout_step_indices = [
1028+
self.scheduler.index_for_timestep(t.to(scheduler_timesteps.device))
1029+
for t in timesteps_cpu
1030+
]
10241031
num_timesteps = timesteps_cpu.shape[0]
10251032
with torch.autocast(
10261033
device_type=current_platform.device_type,
@@ -1104,8 +1111,8 @@ def forward(
11041111
latents, step_log_prob = sde_step_with_logprob(
11051112
self.scheduler,
11061113
model_output=noise_pred,
1107-
timestep=t_device,
11081114
sample=latents,
1115+
step_index=rollout_step_indices[i],
11091116
generator=batch.generator,
11101117
sde_type=rollout_sde_type,
11111118
noise_level=rollout_noise_level,

0 commit comments

Comments
 (0)