@@ -113,9 +113,6 @@ def __init__(
113113 self ._maybe_enable_torch_compile (transformer )
114114
115115 self .scheduler = scheduler
116- self .scheduler .sde_step_with_logprob = sde_step_with_logprob .__get__ (
117- self .scheduler , type (self .scheduler )
118- )
119116 self .vae = vae
120117 self .pipeline = weakref .ref (pipeline ) if pipeline else None
121118
@@ -1014,27 +1011,9 @@ def forward(
10141011 trajectory_log_probs : list [torch .Tensor ] = []
10151012 rollout_enabled = bool (batch .rollout )
10161013 rollout_sde_type = batch .rollout_sde_type
1017- if rollout_sde_type is None or str (rollout_sde_type ).strip () == "" :
1018- if rollout_enabled :
1019- logger .warning ("rollout_sde_type is not set, defaulting to 'sde'." )
1020- rollout_sde_type = "sde"
1021- else :
1022- rollout_sde_type = str (rollout_sde_type ).strip ().lower ()
1023- if rollout_sde_type not in ("sde" , "cps" ):
1024- logger .warning (
1025- "Unknown rollout_sde_type '%s', using default 'sde'." ,
1026- rollout_sde_type ,
1027- )
1028- rollout_sde_type = "sde"
10291014
10301015 rollout_noise_level = batch .rollout_noise_level
10311016
1032- if rollout_enabled and not hasattr (self .scheduler , "sde_step_with_logprob" ):
1033- raise RuntimeError (
1034- f"Rollout is enabled, but scheduler '{ type (self .scheduler ).__name__ } ' "
1035- "does not provide sde_step_with_logprob."
1036- )
1037-
10381017 # Run denoising loop
10391018 denoising_start_time = time .time ()
10401019
@@ -1122,15 +1101,14 @@ def forward(
11221101
11231102 # Compute the previous noisy sample
11241103 if rollout_enabled :
1125- latents , step_log_prob , _ , _ = (
1126- self .scheduler .sde_step_with_logprob (
1127- model_output = noise_pred ,
1128- timestep = t_device ,
1129- sample = latents ,
1130- generator = batch .generator ,
1131- sde_type = rollout_sde_type ,
1132- noise_level = rollout_noise_level ,
1133- )
1104+ latents , step_log_prob = sde_step_with_logprob (
1105+ self .scheduler ,
1106+ model_output = noise_pred ,
1107+ timestep = t_device ,
1108+ sample = latents ,
1109+ generator = batch .generator ,
1110+ sde_type = rollout_sde_type ,
1111+ noise_level = rollout_noise_level ,
11341112 )
11351113 trajectory_log_probs .append (step_log_prob )
11361114 else :
0 commit comments