File tree Expand file tree Collapse file tree 4 files changed +8
-24
lines changed Expand file tree Collapse file tree 4 files changed +8
-24
lines changed Original file line number Diff line number Diff line change @@ -548,16 +548,12 @@ def __len__(self):
548548 return self .config .num_train_timesteps
549549
550550 def previous_timestep (self , timestep ):
551- if self .custom_timesteps :
551+ if self .custom_timesteps or self . num_inference_steps :
552552 index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
553553 if index == self .timesteps .shape [0 ] - 1 :
554554 prev_t = torch .tensor (- 1 )
555555 else :
556556 prev_t = self .timesteps [index + 1 ]
557557 else :
558- num_inference_steps = (
559- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
560- )
561- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
562-
558+ prev_t = timestep - 1
563559 return prev_t
Original file line number Diff line number Diff line change @@ -639,16 +639,12 @@ def __len__(self):
639639
640640 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
641641 def previous_timestep (self , timestep ):
642- if self .custom_timesteps :
642+ if self .custom_timesteps or self . num_inference_steps :
643643 index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
644644 if index == self .timesteps .shape [0 ] - 1 :
645645 prev_t = torch .tensor (- 1 )
646646 else :
647647 prev_t = self .timesteps [index + 1 ]
648648 else :
649- num_inference_steps = (
650- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
651- )
652- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
653-
649+ prev_t = timestep - 1
654650 return prev_t
Original file line number Diff line number Diff line change @@ -643,16 +643,12 @@ def __len__(self):
643643
644644 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
645645 def previous_timestep (self , timestep ):
646- if self .custom_timesteps :
646+ if self .custom_timesteps or self . num_inference_steps :
647647 index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
648648 if index == self .timesteps .shape [0 ] - 1 :
649649 prev_t = torch .tensor (- 1 )
650650 else :
651651 prev_t = self .timesteps [index + 1 ]
652652 else :
653- num_inference_steps = (
654- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
655- )
656- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
657-
653+ prev_t = timestep - 1
658654 return prev_t
Original file line number Diff line number Diff line change @@ -680,16 +680,12 @@ def __len__(self):
680680
681681 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
682682 def previous_timestep (self , timestep ):
683- if self .custom_timesteps :
683+ if self .custom_timesteps or self . num_inference_steps :
684684 index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
685685 if index == self .timesteps .shape [0 ] - 1 :
686686 prev_t = torch .tensor (- 1 )
687687 else :
688688 prev_t = self .timesteps [index + 1 ]
689689 else :
690- num_inference_steps = (
691- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
692- )
693- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
694-
690+ prev_t = timestep - 1
695691 return prev_t
You can’t perform that action at this time.
0 commit comments