@@ -103,17 +103,40 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
          new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
          new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

+        # qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+        # qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+        # q, k, v = qkv_weight.chunk(3, dim=0)
+        # q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+        #
+        # new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        # new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+        # new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        # new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+        # new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        # new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
          qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
          qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+
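+        # The fused query_key_value projection packs Q/K/V per attention head; the view/permute
+        # below assumes a head-major layout of (num_heads, 3, head_dim, hidden_dim), with
+        # num_heads and hidden_dim hard-coded for this checkpoint size.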
+        num_heads = 32
+        hidden_dim = 4096
+        head_dim = qkv_weight.shape[0] // (3 * num_heads)
+        qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
+        qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
+
+        qkv_weight = qkv_weight.permute(1, 0, 2, 3)  # (3, num_heads, head_dim, hidden_dim)
+        qkv_bias = qkv_bias.permute(1, 0, 2)  # (3, num_heads, head_dim)
+
          q, k, v = qkv_weight.chunk(3, dim=0)
          q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
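+        # squeeze(0) drops the leading Q/K/V split dim; reshape flattens the per-head slices back
+        # to the (num_heads * head_dim, hidden_dim) weight shape of a standard nn.Linear projection.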
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
+        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
+

          new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
              old_prefix + "attention.dense.weight"