Commit c7d1227

feat(scheduler): Add CogView scheduler implementation

1 file changed: +332 −0
# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.torch_utils import randn_tensor
from .scheduling_ddim import DDIMSchedulerOutput
from .scheduling_utils import SchedulerMixin


class CogViewScheduler(SchedulerMixin, ConfigMixin):
    """
    `CogViewScheduler` is a DDIM-style scheduler with an SNR-shifted noise schedule and optional zero terminal SNR
    rescaling, used with CogView models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        prediction_type (`str`, defaults to `v_prediction`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of the [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `leading`):
            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        num_inference_steps (`int`, defaults to 50):
            The number of inference steps to use.
        scale_factor (`float`, defaults to 1.0):
            Scaling factor to apply to the model input.
        snr_shift_scale (`float`, defaults to 1.0):
            Scale factor for shifting the signal-to-noise ratio.
        zero_snr (`bool`, defaults to `True`):
            Whether to rescale the alphas to achieve zero terminal SNR.
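
    Examples:
        A minimal usage sketch (the top-level `diffusers` import below is an assumption for illustration; adjust it
        to wherever the class is actually exported, and replace the random tensor with a real denoiser call):

        ```py
        >>> import torch
        >>> from diffusers import CogViewScheduler

        >>> scheduler = CogViewScheduler(num_train_timesteps=1000, zero_snr=True)
        >>> scheduler.set_timesteps(num_inference_steps=50)

        >>> sample = torch.randn(1, 4, 64, 64)
        >>> for t in scheduler.timesteps:
        ...     model_output = torch.randn_like(sample)  # stand-in for a real model prediction
        ...     sample = scheduler.step(model_output, t, sample).prev_sample
        ```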
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        prediction_type: str = "v_prediction",
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        num_inference_steps: int = 50,
        scale_factor: float = 1.0,
        snr_shift_scale: float = 1.0,
        zero_snr: bool = True,
    ):
        # "scaled_linear" beta schedule: linear in sqrt(beta) space
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # SNR shift
        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
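        # Equivalently: the shifted schedule satisfies SNR'(t) = SNR(t) / snr_shift_scale,
        # where SNR(t) = alphas_cumprod / (1 - alphas_cumprod).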
        sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        if zero_snr:
            # Rescale so that the terminal SNR is exactly zero, following
            # https://huggingface.co/papers/2305.08891
            sqrt_alphas_cumprod_0 = sqrt_alphas_cumprod[0]
            sqrt_alphas_cumprod_T_1 = sqrt_alphas_cumprod[-1]
            sqrt_alphas_cumprod -= sqrt_alphas_cumprod_T_1
            sqrt_alphas_cumprod *= sqrt_alphas_cumprod_0 / (sqrt_alphas_cumprod_0 - sqrt_alphas_cumprod_T_1)
        self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
        self.sigmas = torch.sqrt(1 - sqrt_alphas_cumprod**2)
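        # With `zero_snr` enabled, sqrt_alphas_cumprod[-1] == 0 and sigmas[-1] == 1, so the
        # terminal timestep carries pure noise.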

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample * self.config.scale_factor

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
                `num_inference_steps` must be `None`.
        """
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        if timesteps is not None:
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.num_train_timesteps`: {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps
            self.custom_timesteps = False

            # "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                    .round()[::-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = self.config.num_train_timesteps // self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 1.0,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of the added noise in the diffusion step.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise
                a tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper for an in-depth understanding

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else 1.0
        sigma_t = eta * torch.sqrt(
            (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        )
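        # With eta == 0 the update below reduces to deterministic DDIM; with eta == 1, sigma_t matches
        # the DDPM posterior standard deviation (see Eq. (16) of https://arxiv.org/pdf/2010.02502.pdf).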

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
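            # (here v := sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * x_0, following
            # section 2.4 of https://imagen.research.google/video/paper.pdf)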
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5) * pred_epsilon

        # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = sigma_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
        apply_scale: bool = True,
    ) -> torch.Tensor:
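        # Forward diffusion: x_t = sqrt_alphas_cumprod[t] * x_0 + sigmas[t] * noise,
        # where sigmas[t] = sqrt(1 - sqrt_alphas_cumprod[t] ** 2).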
        # Make sure the schedule tensors and timesteps have the same device and dtype as original_samples.
        # Moving them once here avoids redundant CPU-to-GPU data movement for subsequent add_noise calls.
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = self.sqrt_alphas_cumprod[timesteps]
        sigmas = self.sigmas[timesteps]
        assert sqrt_alpha_prod.dim() == 1, f"sqrt_alpha_prod must be a 1D tensor, got {sqrt_alpha_prod.dim()}D"
        assert sqrt_alpha_prod.shape == sigmas.shape, (
            f"sigmas and sqrt_alpha_prod must have the same shape, got {sigmas.shape} and {sqrt_alpha_prod.shape}"
        )
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sigmas = sigmas.unsqueeze(-1)

        if apply_scale:
            original_samples = original_samples * self.config.scale_factor

        # scale noise and original samples
        noise = noise * sigmas
        original_samples = original_samples * sqrt_alpha_prod

        noisy_samples = noise + original_samples
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
