88from diffusers .utils .torch_utils import randn_tensor
99
1010
11- def _as_timestep_tensor (
12- timestep : Union [float , torch .Tensor ], batch_size : int , device : torch .device
13- ) -> torch .Tensor :
14- """Normalize timestep input to a 1D tensor on the target device."""
15- if torch .is_tensor (timestep ):
16- ts = timestep .to (device = device )
17- else :
18- ts = torch .tensor ([timestep ], device = device )
19-
20- if ts .ndim == 0 :
21- ts = ts .view (1 )
22- else :
23- ts = ts .view (- 1 )
24-
25- # Broadcast scalar timestep to match batch size.
26- if ts .numel () == 1 and batch_size > 1 :
27- ts = ts .repeat (batch_size )
28- return ts
29-
30-
3111def sde_step_with_logprob (
3212 self : Any ,
3313 model_output : torch .FloatTensor ,
34- timestep : Union [float , torch .FloatTensor ],
3514 sample : torch .FloatTensor ,
15+ step_index : int ,
3616 noise_level : float = 0.7 ,
3717 prev_sample : Optional [torch .FloatTensor ] = None ,
3818 generator : Optional [Union [torch .Generator , list [torch .Generator ]]] = None ,
@@ -49,10 +29,9 @@ def sde_step_with_logprob(
4929 if prev_sample is not None :
5030 prev_sample = prev_sample .float ()
5131
52- batch_size = sample .shape [0 ]
53- timestep_tensor = _as_timestep_tensor (timestep , batch_size , sample .device )
54- step_indices = torch .tensor (
55- [self .index_for_timestep (t .to (self .timesteps .device )) for t in timestep_tensor ],
32+ step_indices = torch .full (
33+ (sample .shape [0 ],),
34+ int (step_index ),
5635 device = self .sigmas .device ,
5736 dtype = torch .long ,
5837 )
@@ -112,8 +91,13 @@ def sde_step_with_logprob(
11291 )
11392 prev_sample = prev_sample_mean + std_dev_t * variance_noise
11493
115- # Keep the same simplified cps objective used in the original patch.
116- log_prob = - ((prev_sample .detach () - prev_sample_mean ) ** 2 )
94+ # CPS transition is Gaussian with std_dev_t, so compute a valid log-probability.
95+ std = std_dev_t .clamp_min (1e-12 )
96+ log_prob = (
97+ - ((prev_sample .detach () - prev_sample_mean ) ** 2 ) / (2 * (std ** 2 ))
98+ - torch .log (std )
99+ - torch .log (torch .sqrt (torch .as_tensor (2 * math .pi , device = std .device )))
100+ )
117101 else :
118102 raise ValueError (f"Unsupported sde_type: { sde_type } " )
119103
0 commit comments