There appears to be a discrepancy between the paper and the code. The paper suggests decoding at the final_timestep from the latents obtained at the mid_timestep, but the code performs the decoding step at the mid_timestep in both instances. My expected behavior is:
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
noise_pred = self.unet(
    latent_model_input,
    timesteps[mid_timestep],
    prompt_embeds,
    added_cond_kwargs=unet_added_conditions,
).sample
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[final_timestep], latents).pred_original_sample.to(self.weight_dtype)  # final_timestep here, not mid_timestep
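For reference, here is a minimal sketch (my own illustration, not code from this repo or from diffusers) of how a DDPM/DDIM-style scheduler with epsilon prediction recovers pred_original_sample; the timestep passed to step() selects which cumulative alpha is used, which is why stepping with timesteps[mid_timestep] versus timesteps[final_timestep] gives different reconstructions (a v_prediction scheduler would use a different formula):

import torch

def pred_x0(noise_pred: torch.Tensor, latents: torch.Tensor,
            alphas_cumprod: torch.Tensor, t: int) -> torch.Tensor:
    # Epsilon-prediction form: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    # The timestep t decides which alpha_bar is used, so the choice of
    # mid_timestep vs. final_timestep directly changes pred_original_sample.
    alpha_prod_t = alphas_cumprod[t]
    beta_prod_t = 1 - alpha_prod_t
    return (latents - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5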
Code in this repo:
for i, t in enumerate(timesteps[:mid_timestep]):
    with torch.no_grad():
        latent_model_input = latents
        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
        noise_pred = self.unet(
            latent_model_input,
            t,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample
        latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
noise_pred = self.unet(
    latent_model_input,
    timesteps[mid_timestep],
    prompt_embeds,
    added_cond_kwargs=unet_added_conditions,
).sample
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep], latents).pred_original_sample.to(self.weight_dtype)
pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
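If it helps, one way to gauge how much the choice of timestep matters in practice would be to compare the two reconstructions directly. This is a hypothetical check on my side, not code from the repo, and it assumes the scheduler's step() is stateless (true for DDPM/DDIM; solvers that keep internal counters would need to be copied first):

with torch.no_grad():
    # Same noise_pred and latents as above; only the timestep passed to step() differs.
    x0_mid = self.noise_scheduler.step(
        noise_pred, timesteps[mid_timestep], latents
    ).pred_original_sample
    x0_final = self.noise_scheduler.step(
        noise_pred, timesteps[final_timestep], latents
    ).pred_original_sample
    print((x0_mid - x0_final).abs().max())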