     Attention,
     AttentionProcessor,
     CogVideoXAttnProcessor2_0,
-    FusedCogVideoXAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
@@ -277,46 +276,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
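For reference, the removed methods formed a fuse/unfuse pair: fuse_qkv_projections() merged the separate query/key/value linear layers of each Attention module into a single projection and swapped in FusedCogVideoXAttnProcessor2_0, while unfuse_qkv_projections() restored the saved processors. A minimal usage sketch against the pre-removal API (the checkpoint name "THUDM/CogVideoX-2b" is an assumption for illustration, not taken from this diff):

    import torch
    from diffusers import CogVideoXTransformer3DModel

    # Load the transformer on its own; checkpoint name assumed for illustration.
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16
    )

    # Pre-removal API: fuse the q/k/v linears of each Attention module into a
    # single projection and switch to the fused attention processor.
    transformer.fuse_qkv_projections()

    # ... run inference with the fused projections ...

    # Restore the original (unfused) attention processors.
    transformer.unfuse_qkv_projections()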
|