Why not timesteps[final_timestep] if you want to decode x_0? #112

@jsrdcht

There appears to be a discrepancy between the paper and the code. As I read the paper, x_0 (the sample at final_timestep) is decoded from the latents at mid_timestep, but the code passes timesteps[mid_timestep] to both the UNet call and the scheduler step. The behavior I expected is:

        latent_model_input = latents
        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
        noise_pred = self.unet(
            latent_model_input,
            timesteps[mid_timestep],
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample
        # Changed: pass timesteps[final_timestep] to the scheduler step instead of timesteps[mid_timestep]
        pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[final_timestep], latents).pred_original_sample.to(self.weight_dtype)

Code in this repo:

        for i, t in enumerate(timesteps[:mid_timestep]):
            with torch.no_grad():
                latent_model_input = latents
                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    prompt_embeds,
                    added_cond_kwargs=unet_added_conditions,
                ).sample
                latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
        
        latent_model_input = latents
        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
        noise_pred = self.unet(
            latent_model_input,
            timesteps[mid_timestep],
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample
        pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep], latents).pred_original_sample.to(self.weight_dtype)
        
        pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
        image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
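
For context on the question, here is a minimal sketch of how an epsilon-prediction, DDPM-style scheduler typically recovers x_0 from a noisy sample. It is an illustration under stated assumptions, not this repo's scheduler: the linear beta schedule, the scalar "latents", and the helper names `make_alphas_cumprod`, `add_noise`, and `predict_x0` are all hypothetical. In this formulation the timestep given to the x_0 prediction only selects the cumulative alpha that describes how noisy the input is; whether the same holds here depends on which scheduler the repo uses.

```python
import math


def make_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for an assumed linear beta schedule."""
    alphas_cumprod = []
    prod = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def add_noise(x0, eps, t, alphas_cumprod):
    """Forward process: x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps."""
    a = alphas_cumprod[t]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps


def predict_x0(x_t, eps_pred, t, alphas_cumprod):
    """Invert the forward process: x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)."""
    a = alphas_cumprod[t]
    return (x_t - math.sqrt(1.0 - a) * eps_pred) / math.sqrt(a)


alphas_cumprod = make_alphas_cumprod()
x0, eps = 0.7, -1.3        # toy scalar sample and noise
t_mid, t_final = 500, 0    # hypothetical mid and final timesteps

# Noise the sample up to the mid timestep.
x_mid = add_noise(x0, eps, t_mid, alphas_cumprod)

# Predict x0 using the timestep that matches the noise level of x_mid.
x0_same_t = predict_x0(x_mid, eps, t_mid, alphas_cumprod)

# Predict x0 from the same x_mid, but with the final timestep's alpha.
x0_other_t = predict_x0(x_mid, eps, t_final, alphas_cumprod)

print(x0_same_t, x0_other_t)
```

In this toy setup, only the call whose timestep matches the noise level of `x_mid` returns the original sample, even though both calls aim at x_0.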
