
Commit 5cc2ade

Merge branch 'main' into cogview3-plus
2 parents 80e7cca + 38a3e4d commit 5cc2ade

18 files changed: +1559 -8 lines

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
 - all
 - __call__
 
+## StableDiffusionPAGImg2ImgPipeline
+[[autodoc]] StableDiffusionPAGImg2ImgPipeline
+- all
+- __call__
+
 ## StableDiffusionControlNetPAGPipeline
 [[autodoc]] StableDiffusionControlNetPAGPipeline
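
The docs hunk above adds the API reference for the new pipeline. For context, a minimal usage sketch (not part of this commit): it assumes AutoPipelineForImage2Image accepts enable_pag=True for Stable Diffusion 1.5 checkpoints, as it already does for the SDXL PAG variants, and that pag_scale follows the other PAG pipelines; the image path is a placeholder.

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# enable_pag=True routes the img2img task to the new StableDiffusionPAGImg2ImgPipeline
pipe = AutoPipelineForImage2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    enable_pag=True,
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("path/to/init_image.png")  # placeholder input image

image = pipe(
    prompt="a fantasy landscape, detailed oil painting",
    image=init_image,
    strength=0.75,
    pag_scale=3.0,  # strength of perturbed-attention guidance
).images[0]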

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -346,6 +346,7 @@
 "StableDiffusionLatentUpscalePipeline",
 "StableDiffusionLDM3DPipeline",
 "StableDiffusionModelEditingPipeline",
+"StableDiffusionPAGImg2ImgPipeline",
 "StableDiffusionPAGPipeline",
 "StableDiffusionPanoramaPipeline",
 "StableDiffusionParadigmsPipeline",

@@ -799,6 +800,7 @@
 StableDiffusionLatentUpscalePipeline,
 StableDiffusionLDM3DPipeline,
 StableDiffusionModelEditingPipeline,
+StableDiffusionPAGImg2ImgPipeline,
 StableDiffusionPAGPipeline,
 StableDiffusionPanoramaPipeline,
 StableDiffusionParadigmsPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,7 @@
 "HunyuanDiTPAGPipeline",
 "StableDiffusion3PAGPipeline",
 "StableDiffusionPAGPipeline",
+"StableDiffusionPAGImg2ImgPipeline",
 "StableDiffusionControlNetPAGPipeline",
 "StableDiffusionXLPAGPipeline",
 "StableDiffusionXLPAGInpaintPipeline",

@@ -571,6 +572,7 @@
 StableDiffusion3PAGPipeline,
 StableDiffusionControlNetPAGInpaintPipeline,
 StableDiffusionControlNetPAGPipeline,
+StableDiffusionPAGImg2ImgPipeline,
 StableDiffusionPAGPipeline,
 StableDiffusionXLControlNetPAGImg2ImgPipeline,
 StableDiffusionXLControlNetPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,7 @@
 StableDiffusion3PAGPipeline,
 StableDiffusionControlNetPAGInpaintPipeline,
 StableDiffusionControlNetPAGPipeline,
+StableDiffusionPAGImg2ImgPipeline,
 StableDiffusionPAGPipeline,
 StableDiffusionXLControlNetPAGImg2ImgPipeline,
 StableDiffusionXLControlNetPAGPipeline,

@@ -133,6 +134,7 @@
 ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
 ("kandinsky3", Kandinsky3Img2ImgPipeline),
 ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
+("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
 ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
 ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
 ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 8 additions & 0 deletions
@@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1089,6 +1093,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1235,6 +1240,9 @@
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
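
The pattern is the same in each of the ControlNet pipelines below: an interrupt property is exposed, the flag is reset at the start of __call__, and the denoising loop skips the remaining steps once the flag is set. One way to use it from the outside, sketched under the assumption that the pipeline exposes callback_on_step_end as the other Stable Diffusion pipelines do (checkpoint names and the control-image path are illustrative):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

def stop_after(step_limit):
    # Flip the private flag read by the `interrupt` property; the loop then
    # skips every remaining step.
    def callback(pipeline, step, timestep, callback_kwargs):
        if step >= step_limit:
            pipeline._interrupt = True
        return callback_kwargs
    return callback

canny_image = load_image("path/to/canny_edges.png")  # placeholder control image
image = pipe(
    "a futuristic city at sunset",
    image=canny_image,
    num_inference_steps=50,
    callback_on_step_end=stop_after(10),  # bail out after 10 denoising steps
).images[0]

Because the loop skips steps with continue rather than raising, the latents computed so far are still decoded, so the returned image is simply the partially denoised result.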

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 8 additions & 0 deletions
@@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1081,6 +1085,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1211,6 +1216,9 @@
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 8 additions & 0 deletions
@@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1191,6 +1195,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1375,6 +1380,9 @@
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 8 additions & 0 deletions
@@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1427,6 +1431,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 8 additions & 0 deletions
@@ -990,6 +990,10 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1245,6 +1249,7 @@ def __call__(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._denoising_end = denoising_end
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1442,6 +1447,9 @@
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 8 additions & 0 deletions
@@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -1338,6 +1342,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

@@ -1510,6 +1515,9 @@
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
