Skip to content

Commit 5d33f3f

Browse files
Fix qkv conversion logic for CogView4 to Diffusers format
1 parent b04f15d commit 5d33f3f

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

scripts/convert_cogview4_to_diffusers.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -103,17 +103,40 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
103103
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
104104
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
105105

106+
# qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
107+
# qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
108+
# q, k, v = qkv_weight.chunk(3, dim=0)
109+
# q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
110+
#
111+
# new_state_dict[block_prefix + "attn1.to_q.weight"] = q
112+
# new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
113+
# new_state_dict[block_prefix + "attn1.to_k.weight"] = k
114+
# new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
115+
# new_state_dict[block_prefix + "attn1.to_v.weight"] = v
116+
# new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
117+
106118
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
107119
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
120+
121+
num_heads = 32
122+
hidden_dim = 4096
123+
head_dim = qkv_weight.shape[0] // (3 * num_heads)
124+
qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
125+
qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
126+
127+
qkv_weight = qkv_weight.permute(1, 0, 2, 3) # (3, num_heads, head_dim, hidden_dim)
128+
qkv_bias = qkv_bias.permute(1, 0, 2) # (3, num_heads, head_dim)
129+
108130
q, k, v = qkv_weight.chunk(3, dim=0)
109131
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
110132

111-
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
112-
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
113-
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
114-
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
115-
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
116-
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
133+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
134+
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
135+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
136+
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
137+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
138+
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
139+
117140

118141
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
119142
old_prefix + "attention.dense.weight"

0 commit comments

Comments (0)