We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f826aec — commit 6a3a07f (copy full SHA for 6a3a07f)
src/diffusers/models/attention_processor.py
@@ -2837,9 +2837,9 @@ def __call__(
2837
inner_dim = key.shape[-1]
2838
head_dim = inner_dim // attn.heads
2839
2840
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2841
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2842
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
2843
2844
###############################################3
2845
# TODO: Compute qkv directly from qkv_weight (note: split out num_heads and head_dim first), then split q/k/v along the head_dim axis
@@ -2850,7 +2850,6 @@ def __call__(
2850
qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
2851
query, key, value = qkv.chunk(3, dim=-1)
2852
2853
-
2854
# TODO: Verify that RoPE is applied correctly (currently ~25% error)
2855
2856
0 commit comments