@@ -103,40 +103,17 @@ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
         new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
         new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
 
-        # qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
-        # qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
-        # q, k, v = qkv_weight.chunk(3, dim=0)
-        # q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
-        #
-        # new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        # new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
-        # new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        # new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
-        # new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        # new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
-
         qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
         qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
-
-        num_heads = 32
-        hidden_dim = 4096
-        head_dim = qkv_weight.shape[0] // (3 * num_heads)
-        qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_dim)
-        qkv_bias = qkv_bias.view(num_heads, 3, head_dim)
-
-        qkv_weight = qkv_weight.permute(1, 0, 2, 3)  # (3, num_heads, head_dim, hidden_dim)
-        qkv_bias = qkv_bias.permute(1, 0, 2)  # (3, num_heads, head_dim)
-
         q, k, v = qkv_weight.chunk(3, dim=0)
         q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
 
-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias.squeeze(0).reshape(num_heads * head_dim)
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias.squeeze(0).reshape(num_heads * head_dim)
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v.squeeze(0).reshape(num_heads * head_dim, hidden_dim)
-        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias.squeeze(0).reshape(num_heads * head_dim)
-
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
 
         new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
             old_prefix + "attention.dense.weight"
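For reference, a minimal standalone sketch (illustrative toy dimensions only, not the real CogView4 checkpoint shapes) of the chunk(3, dim=0) split that the simplified code relies on, assuming the fused query_key_value tensor stacks q, k, and v contiguously along dim 0:

import torch

# Toy dimensions for illustration only; the actual checkpoint is much larger.
hidden_dim = 8

# Fused projection: q, k, v weights assumed stacked contiguously along dim 0.
qkv_weight = torch.randn(3 * hidden_dim, hidden_dim)
qkv_bias = torch.randn(3 * hidden_dim)

# A plain chunk along dim 0 recovers the three projections.
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

assert q.shape == (hidden_dim, hidden_dim)
assert torch.equal(q, qkv_weight[:hidden_dim])         # q is the first block
assert torch.equal(v_bias, qkv_bias[2 * hidden_dim:])  # v bias is the last block

Note that this plain chunk only matches the removed per-head view/permute path if the checkpoint really stores the three projections as contiguous blocks rather than interleaved per attention head.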