Commit e239c3c

Revert to the SAT-to-CogView4 version
1 parent: b889b37

1 file changed: 6 additions, 29 deletions


scripts/convert_cogview4_to_diffusers.py

@@ -103,40 +103,17 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
         new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
         new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
 
-        # qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
-        # qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
-        # q, k, v = qkv_weight.chunk(3, dim=0)
-        # q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
-        #
-        # new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        # new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
-        # new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        # new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
-        # new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        # new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
-
         qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
         qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
-
-        num_heads = 32
-        hidden_dim = 4096
-        head_dim = qkv_weight.shape[0] // (3 * num_heads)
-        qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
-        qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
-
-        qkv_weight = qkv_weight.permute(1, 0, 2, 3)  # (3, num_heads, head_dim, hidden_dim)
-        qkv_bias = qkv_bias.permute(1, 0, 2)  # (3, num_heads, head_dim)
-
         q, k, v = qkv_weight.chunk(3, dim=0)
         q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
 
-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
-
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
 
         new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
             old_prefix + "attention.dense.weight"
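
For context, the mapping this commit restores reduces to a plain chunk of the fused SAT attention projection along dim 0. Below is a minimal, self-contained sketch of that step; the random tensors and hidden_dim = 4096 (taken from the removed code) are illustrative stand-ins for a real checkpoint, and the key names mirror the script:

import torch

hidden_dim = 4096  # assumed CogView4 hidden size, as in the removed code

# Stand-in for the SAT checkpoint entries the script pops; a real
# original_state_dict would hold trained weights with these shapes.
fused = {
    "attention.query_key_value.weight": torch.randn(3 * hidden_dim, hidden_dim),
    "attention.query_key_value.bias": torch.randn(3 * hidden_dim),
}

new_state_dict = {}
qkv_weight = fused.pop("attention.query_key_value.weight")
qkv_bias = fused.pop("attention.query_key_value.bias")

# The revert assumes the fused tensor is laid out as [Q; K; V] blocks along
# dim 0, so a plain chunk recovers the three projections directly.
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

new_state_dict["attn1.to_q.weight"], new_state_dict["attn1.to_q.bias"] = q, q_bias
new_state_dict["attn1.to_k.weight"], new_state_dict["attn1.to_k.bias"] = k, k_bias
new_state_dict["attn1.to_v.weight"], new_state_dict["attn1.to_v.bias"] = v, v_bias

assert new_state_dict["attn1.to_q.weight"].shape == (hidden_dim, hidden_dim)
assert new_state_dict["attn1.to_q.bias"].shape == (hidden_dim,)

The removed code instead viewed the fused tensor as (num_heads, 3, head_dim, hidden_dim) and permuted it before chunking, which is only necessary when a checkpoint interleaves Q, K and V per attention head; the revert assumes the simpler block layout.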
