Skip to content

Commit 6a3a07f

Browse files
committed
[bugfix] fix dimension mismatch in CogView4 attention
1 parent f826aec commit 6a3a07f

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,9 +2837,9 @@ def __call__(
28372837
inner_dim = key.shape[-1]
28382838
head_dim = inner_dim // attn.heads
28392839

2840-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2841-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2842-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2840+
query = query.view(batch_size, -1, attn.heads, head_dim)
2841+
key = key.view(batch_size, -1, attn.heads, head_dim)
2842+
value = value.view(batch_size, -1, attn.heads, head_dim)
28432843

28442844
###############################################3
28452845
# TODO: 直接用qkv_weight算出qkv(注意要先分出num_heads, head_dim),再在head_dims上拆出qkv
@@ -2850,7 +2850,6 @@ def __call__(
28502850
qkv = qkv.view(batch_size, -1, attn.heads, head_dim * 3)
28512851
query, key, value = qkv.chunk(3, dim=-1)
28522852

2853-
28542853
# TODO: 校验rope是否apply正确(目前有25%的误差)
28552854
###############################################3
28562855

0 commit comments

Comments (0)