Skip to content

Commit dcaadb8

Browse files
committed
[diffusion] fix: address rollout log_prob review feedback (#18806)
1 parent: d2c5cad; commit: dcaadb8

File tree

2 files changed

+17
-38
lines changed

2 files changed

+17
-38
lines changed

python/sglang/multimodal_gen/runtime/pipelines/patches/flow_matching_with_logprob.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,14 @@ def sde_step_with_logprob(
5353
timestep_tensor = _as_timestep_tensor(timestep, batch_size, sample.device)
5454
step_indices = torch.tensor(
5555
[self.index_for_timestep(t.to(self.timesteps.device)) for t in timestep_tensor],
56-
device=sample.device,
56+
device=self.sigmas.device,
5757
dtype=torch.long,
5858
)
5959
prev_step_indices = (step_indices + 1).clamp_max(len(self.sigmas) - 1)
60-
step_indices = step_indices.to(device=self.sigmas.device)
61-
prev_step_indices = prev_step_indices.to(device=self.sigmas.device)
62-
63-
sigma = self.sigmas[step_indices].to(sample.device).to(sample.dtype)
64-
sigma_prev = self.sigmas[prev_step_indices].to(sample.device).to(sample.dtype)
60+
sigma = self.sigmas[step_indices].to(device=sample.device, dtype=sample.dtype)
61+
sigma_prev = self.sigmas[prev_step_indices].to(
62+
device=sample.device, dtype=sample.dtype
63+
)
6564
sigma = sigma.view(-1, *([1] * (sample.ndim - 1)))
6665
sigma_prev = sigma_prev.view(-1, *([1] * (sample.ndim - 1)))
6766
sigma_max = self.sigmas[min(1, len(self.sigmas) - 1)].to(
@@ -70,7 +69,9 @@ def sde_step_with_logprob(
7069
dt = sigma_prev - sigma
7170

7271
if sde_type == "sde":
73-
denom_sigma = 1 - torch.where(sigma == 1, sigma_max, sigma)
72+
denom_sigma = 1 - torch.where(
73+
torch.isclose(sigma, sigma.new_tensor(1.0)), sigma_max, sigma
74+
)
7475
std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
7576
prev_sample_mean = (
7677
sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
@@ -117,4 +118,4 @@ def sde_step_with_logprob(
117118
raise ValueError(f"Unsupported sde_type: {sde_type}")
118119

119120
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
120-
return prev_sample.to(sample_dtype), log_prob, prev_sample_mean, std_dev_t
121+
return prev_sample.to(sample_dtype), log_prob

python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py

Lines changed: 8 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,6 @@ def __init__(
113113
self._maybe_enable_torch_compile(transformer)
114114

115115
self.scheduler = scheduler
116-
self.scheduler.sde_step_with_logprob = sde_step_with_logprob.__get__(
117-
self.scheduler, type(self.scheduler)
118-
)
119116
self.vae = vae
120117
self.pipeline = weakref.ref(pipeline) if pipeline else None
121118

@@ -1014,27 +1011,9 @@ def forward(
10141011
trajectory_log_probs: list[torch.Tensor] = []
10151012
rollout_enabled = bool(batch.rollout)
10161013
rollout_sde_type = batch.rollout_sde_type
1017-
if rollout_sde_type is None or str(rollout_sde_type).strip() == "":
1018-
if rollout_enabled:
1019-
logger.warning("rollout_sde_type is not set, defaulting to 'sde'.")
1020-
rollout_sde_type = "sde"
1021-
else:
1022-
rollout_sde_type = str(rollout_sde_type).strip().lower()
1023-
if rollout_sde_type not in ("sde", "cps"):
1024-
logger.warning(
1025-
"Unknown rollout_sde_type '%s', using default 'sde'.",
1026-
rollout_sde_type,
1027-
)
1028-
rollout_sde_type = "sde"
10291014

10301015
rollout_noise_level = batch.rollout_noise_level
10311016

1032-
if rollout_enabled and not hasattr(self.scheduler, "sde_step_with_logprob"):
1033-
raise RuntimeError(
1034-
f"Rollout is enabled, but scheduler '{type(self.scheduler).__name__}' "
1035-
"does not provide sde_step_with_logprob."
1036-
)
1037-
10381017
# Run denoising loop
10391018
denoising_start_time = time.time()
10401019

@@ -1122,15 +1101,14 @@ def forward(
11221101

11231102
# Compute the previous noisy sample
11241103
if rollout_enabled:
1125-
latents, step_log_prob, _, _ = (
1126-
self.scheduler.sde_step_with_logprob(
1127-
model_output=noise_pred,
1128-
timestep=t_device,
1129-
sample=latents,
1130-
generator=batch.generator,
1131-
sde_type=rollout_sde_type,
1132-
noise_level=rollout_noise_level,
1133-
)
1104+
latents, step_log_prob = sde_step_with_logprob(
1105+
self.scheduler,
1106+
model_output=noise_pred,
1107+
timestep=t_device,
1108+
sample=latents,
1109+
generator=batch.generator,
1110+
sde_type=rollout_sde_type,
1111+
noise_level=rollout_noise_level,
11341112
)
11351113
trajectory_log_probs.append(step_log_prob)
11361114
else:

0 commit comments

Comments
 (0)