Skip to content

Commit 2158f00

Browse files
committed
remove qkv fusion
1 parent 9d9b0b2 commit 2158f00

File tree

2 files changed

+0
-54
lines changed

2 files changed

+0
-54
lines changed

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
Attention,
2525
AttentionProcessor,
2626
CogVideoXAttnProcessor2_0,
27-
FusedCogVideoXAttnProcessor2_0,
2827
)
2928
from ...models.modeling_utils import ModelMixin
3029
from ...models.normalization import AdaLayerNormContinuous
@@ -277,46 +276,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
277276
for name, module in self.named_children():
278277
fn_recursive_attn_processor(name, module, processor)
279278

280-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
281-
def fuse_qkv_projections(self):
282-
"""
283-
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
284-
are fused. For cross-attention modules, key and value projection matrices are fused.
285-
286-
<Tip warning={true}>
287-
288-
This API is 🧪 experimental.
289-
290-
</Tip>
291-
"""
292-
self.original_attn_processors = None
293-
294-
for _, attn_processor in self.attn_processors.items():
295-
if "Added" in str(attn_processor.__class__.__name__):
296-
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
297-
298-
self.original_attn_processors = self.attn_processors
299-
300-
for module in self.modules():
301-
if isinstance(module, Attention):
302-
module.fuse_projections(fuse=True)
303-
304-
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
305-
306-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
307-
def unfuse_qkv_projections(self):
308-
"""Disables the fused QKV projection if enabled.
309-
310-
<Tip warning={true}>
311-
312-
This API is 🧪 experimental.
313-
314-
</Tip>
315-
316-
"""
317-
if self.original_attn_processors is not None:
318-
self.set_attn_processor(self.original_attn_processors)
319-
320279
def _set_gradient_checkpointing(self, module, value=False):
321280
if hasattr(module, "gradient_checkpointing"):
322281
module.gradient_checkpointing = value

src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -397,19 +397,6 @@ def check_inputs(
397397
f" {negative_prompt_embeds.shape}."
398398
)
399399

400-
def fuse_qkv_projections(self) -> None:
401-
r"""Enables fused QKV projections."""
402-
self.fusing_transformer = True
403-
self.transformer.fuse_qkv_projections()
404-
405-
def unfuse_qkv_projections(self) -> None:
406-
r"""Disable QKV projection fusion if enabled."""
407-
if not self.fusing_transformer:
408-
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
409-
else:
410-
self.transformer.unfuse_qkv_projections()
411-
self.fusing_transformer = False
412-
413400
@property
414401
def guidance_scale(self):
415402
return self._guidance_scale

0 commit comments

Comments (0)