import torch
from collections import OrderedDict
from diffusers import CogView4Transformer2DModel
from tqdm import tqdm


def load_state_dict_sat(file_path):
    """Load the SAT state dictionary from a given file path."""
    # Typically, the stored SAT ckpt is in the format: {'module': {...}}
    ckpt = torch.load(file_path, map_location="cuda")
    return ckpt["module"]
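

# A minimal alternative sketch (assumption: the checkpoint fits in host RAM): for a
# pure weight comparison the tensors never need to live on the GPU, so loading with
# torch.load(file_path, map_location="cpu") works equally well and frees GPU memory.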


def extract_qkv_from_sat(state_dict, layer_idx):
    """
    Extract QKV weights and biases from a SAT state_dict.
    Expects keys like:
        model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value
    """
    prefix = f"model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value"
    w = state_dict[f"{prefix}.weight"].clone()
    b = state_dict[f"{prefix}.bias"].clone()
    return (w, b)
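

# Quick sanity check (a sketch with hypothetical names, assuming hidden_dim=4096 as
# in the __main__ block below): the fused SAT tensors stack Q, K, V row-wise.
#
#   w, b = extract_qkv_from_sat(sat_state_dict, 0)
#   assert w.shape == (3 * 4096, 4096) and b.shape == (3 * 4096,)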


def load_state_dict_cogview(cogview_path):
    """
    Load the CogView4 model from diffusers and return its state_dict().
    NOTE: Adjust 'torch_dtype' and 'device_map' as appropriate.
    """
    cogview_model = CogView4Transformer2DModel.from_pretrained(
        cogview_path, torch_dtype=torch.bfloat16, device_map="auto"
    )
    return cogview_model.state_dict()
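

# Note: with device_map="auto", diffusers may shard parameters across several GPUs,
# so tensors from different sources can land on different devices. A minimal sketch
# (assumption: host RAM suffices) that pins everything to the CPU up front:
#
#   state_dict = {k: v.cpu() for k, v in cogview_model.state_dict().items()}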


def extract_qkv_from_cogview(state_dict, layer_idx, num_heads, head_dim, hidden_dim):
    """
    Extract Q, K, V from a CogView4 checkpoint and reshape them into the same shape as SAT's QKV.
    For each layer i:
        Q prefix: transformer_blocks.{layer_idx}.attn1.to_q
        K prefix: transformer_blocks.{layer_idx}.attn1.to_k
        V prefix: transformer_blocks.{layer_idx}.attn1.to_v
    The final shape must match SAT's [3*hidden_dim, hidden_dim] for the weight and [3*hidden_dim] for the bias.
    """
    q_prefix = f"transformer_blocks.{layer_idx}.attn1.to_q"
    k_prefix = f"transformer_blocks.{layer_idx}.attn1.to_k"
    v_prefix = f"transformer_blocks.{layer_idx}.attn1.to_v"

    # Extract the per-projection weights and biases
    q_weight = state_dict[f"{q_prefix}.weight"].clone()
    k_weight = state_dict[f"{k_prefix}.weight"].clone()
    v_weight = state_dict[f"{v_prefix}.weight"].clone()

    q_bias = state_dict[f"{q_prefix}.bias"].clone()
    k_bias = state_dict[f"{k_prefix}.bias"].clone()
    v_bias = state_dict[f"{v_prefix}.bias"].clone()

    # Reshape weights: [hidden_dim, hidden_dim] -> [num_heads, head_dim, hidden_dim],
    # then concatenate along dim 0 (giving 3*num_heads) and flatten to 3*hidden_dim rows
    q_weight = q_weight.view(num_heads, head_dim, hidden_dim)
    k_weight = k_weight.view(num_heads, head_dim, hidden_dim)
    v_weight = v_weight.view(num_heads, head_dim, hidden_dim)

    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)  # (3*num_heads, head_dim, hidden_dim)
    qkv_weight = qkv_weight.view(3 * num_heads * head_dim, hidden_dim)  # flatten to (3*hidden_dim, hidden_dim)

    # Reshape biases: [hidden_dim] -> [num_heads, head_dim]
    q_bias = q_bias.view(num_heads, head_dim)
    k_bias = k_bias.view(num_heads, head_dim)
    v_bias = v_bias.view(num_heads, head_dim)

    qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)  # (3*num_heads, head_dim)
    qkv_bias = qkv_bias.view(3 * num_heads * head_dim)

    return (qkv_weight, qkv_bias)
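

# Layout self-check (a sketch with a hypothetical state_dict 'sd', assuming
# num_heads=32, head_dim=128, hidden_dim=4096): the fused rows are projection-major,
# i.e. rows [0, hidden_dim) are exactly the original Q weight, then K, then V.
#
#   w, b = extract_qkv_from_cogview(sd, 0, 32, 128, 4096)
#   assert torch.equal(w[:4096], sd["transformer_blocks.0.attn1.to_q.weight"])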


def create_sat_state_dict_from_megatron(megatron_ckpt_dict, num_layers=48, num_heads=32, hidden_size=3072):
    """
    Convert a loaded Megatron checkpoint's 'model' dictionary into the format
    used by SAT. The returned OrderedDict is keyed like the inner dict of an
    SAT checkpoint's {'module': {...}}, so it can be compared directly against
    the output of load_state_dict_sat().

    Adapted from the original 'create_sat_state_dict' helper, renamed here for
    clarity.
    """
    hidden_size_per_head = hidden_size // num_heads
    mega_weight = megatron_ckpt_dict["model"]
    sat_weight = {}

    # --- patch_embed ---
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.weight"] = \
        mega_weight["encoder_expand_linear.weight"].reshape(hidden_size, 64).clone()
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.bias"] = \
        mega_weight["encoder_expand_linear.bias"].clone()

    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.weight"] = \
        mega_weight["text_projector.weight"].clone()
    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.bias"] = \
        mega_weight["text_projector.bias"].clone()

    # --- time embedding ---
    sat_weight["model.diffusion_model.time_embed.0.weight"] = \
        mega_weight["time_embedding.time_embed.0.weight"].clone()
    sat_weight["model.diffusion_model.time_embed.0.bias"] = \
        mega_weight["time_embedding.time_embed.0.bias"].clone()
    sat_weight["model.diffusion_model.time_embed.2.weight"] = \
        mega_weight["time_embedding.time_embed.2.weight"].clone()
    sat_weight["model.diffusion_model.time_embed.2.bias"] = \
        mega_weight["time_embedding.time_embed.2.bias"].clone()

    # --- label embedding ---
    sat_weight["model.diffusion_model.label_emb.0.0.weight"] = \
        mega_weight["label_embedding.label_embed.0.weight"].clone()
    sat_weight["model.diffusion_model.label_emb.0.0.bias"] = \
        mega_weight["label_embedding.label_embed.0.bias"].clone()
    sat_weight["model.diffusion_model.label_emb.0.2.weight"] = \
        mega_weight["label_embedding.label_embed.2.weight"].clone()
    sat_weight["model.diffusion_model.label_emb.0.2.bias"] = \
        mega_weight["label_embedding.label_embed.2.bias"].clone()

    # --- layers ---
    for i in tqdm(range(num_layers), desc="Converting Megatron->SAT"):
        # attention output projection
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.weight"] = \
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.bias"] = \
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.bias"].clone()

        # QKV
        qkv_weight = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.weight"].clone()
        qkv_bias = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.bias"].clone()

        # Reshape QKV from Megatron's head-major layout into SAT's projection-major layout:
        # [3*hidden_size, hidden_size] -> [num_heads, 3, hidden_size_per_head, hidden_size]
        # -> permute to [3, num_heads, hidden_size_per_head, hidden_size] -> [3*hidden_size, hidden_size]
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.weight"] = \
            qkv_weight.view(num_heads, 3, hidden_size_per_head, hidden_size) \
                .permute(1, 0, 2, 3) \
                .reshape(3 * hidden_size, hidden_size).clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.bias"] = \
            qkv_bias.view(num_heads, 3, hidden_size_per_head) \
                .permute(1, 0, 2) \
                .reshape(3 * hidden_size) \
                .clone()

        # MLP
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.weight"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.bias"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.bias"].clone()

        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.weight"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.bias"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.bias"].clone()

        # AdaLN
        adaln_weight = mega_weight[f"decoder.layers.{i}.adaln.weight"].clone()
        adaln_bias = mega_weight[f"decoder.layers.{i}.adaln.bias"].clone()

        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.weight"] = adaln_weight.clone()
        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.bias"] = adaln_bias.clone()

    # --- final layers ---
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.weight"] = \
        mega_weight["adaln_final.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.bias"] = \
        mega_weight["adaln_final.bias"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.weight"] = \
        mega_weight["output_projector.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.bias"] = \
        mega_weight["output_projector.bias"].clone()

    return OrderedDict(sat_weight)
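

# Layout note (a toy sketch of the permutation above, not from the original file):
# Megatron fuses QKV head-major, [h0_q, h0_k, h0_v, h1_q, ...], while SAT fuses
# projection-major, [q_all_heads, k_all_heads, v_all_heads]. With num_heads=2 and
# head_dim=1, the row order maps [q0, k0, v0, q1, k1, v1] -> [q0, q1, k0, k1, v0, v1]:
#
#   rows = torch.arange(6).view(2, 3, 1)     # [num_heads, 3, head_dim]
#   print(rows.permute(1, 0, 2).reshape(6))  # tensor([0, 3, 1, 4, 2, 5])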


def load_state_dict_megatron_and_convert_to_sat(megatron_ckpt_path, num_layers, num_heads, hidden_size):
    """
    Load a Megatron checkpoint from <megatron_ckpt_path>, then convert it into
    an SAT-style OrderedDict for direct QKV comparison.

    Typically, <megatron_ckpt_path> = ".../iter_0287500/mp_rank_00/model_optim_rng.pt"
    """
    ckpt = torch.load(megatron_ckpt_path, map_location="cuda")
    # Convert the raw Megatron weights into SAT's key and layout conventions
    sat_like_weight = create_sat_state_dict_from_megatron(
        ckpt, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size
    )
    return sat_like_weight


def compute_l2_difference(tensor1, tensor2):
    """Compute the L2 norm of the difference between two tensors."""
    return torch.norm(tensor1 - tensor2, p=2).item()
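

# A hypothetical companion metric (not part of the original script): relative L2
# error is often easier to read across tensors of different scale, since the raw
# L2 norm grows with tensor size.
def compute_relative_l2_difference(tensor1, tensor2, eps=1e-12):
    """Return ||t1 - t2||_2 / max(||t1||_2, eps), a scale-invariant variant."""
    denom = torch.norm(tensor1, p=2).clamp_min(eps)
    return (torch.norm(tensor1 - tensor2, p=2) / denom).item()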


def compare_qkv(qkv1, qkv2, name1="Model1", name2="Model2", atol=1e-6):
    """
    Compare QKV from two different sources (each a tuple of (weight, bias)).
    Returns (weight_match, bias_match, weight_l2_diff, bias_l2_diff).
    """
    # Cast both sides to float32 on the CPU so tensors loaded with different
    # dtypes (e.g. bfloat16 vs float32) or on different devices can be compared
    w1, b1 = (t.detach().float().cpu() for t in qkv1)
    w2, b2 = (t.detach().float().cpu() for t in qkv2)

    weight_match = torch.allclose(w1, w2, atol=atol)
    bias_match = torch.allclose(b1, b2, atol=atol)
    weight_l2_diff = compute_l2_difference(w1, w2)
    bias_l2_diff = compute_l2_difference(b1, b2)

    if not (weight_match and bias_match):
        print(f"[QKV Mismatch] {name1} vs {name2}")
        print(f"  Weight L2: {weight_l2_diff:.6f}, Bias L2: {bias_l2_diff:.6f}")
    else:
        print(f"[QKV Match] {name1} vs {name2} (Weight L2={weight_l2_diff:.6f}, Bias L2={bias_l2_diff:.6f})")

    return weight_match, bias_match, weight_l2_diff, bias_l2_diff


if __name__ == "__main__":
    num_layers = 28
    num_heads = 32
    hidden_dim = 4096
    head_dim = hidden_dim // num_heads

    sat_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_sat/0287500/mp_rank_00_model_states.pt"
    sat_state_dict = load_state_dict_sat(sat_ckpt_path)

    cogview_path = "/share/zyx/CogView4-6B-0128/transformer"  # directory containing the diffusers model index
    cogview_state_dict = load_state_dict_cogview(cogview_path)

    megatron_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_ema/iter_0287500/mp_rank_00/model_optim_rng.pt"
    mega_as_sat_state_dict = load_state_dict_megatron_and_convert_to_sat(
        megatron_ckpt_path,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_size=hidden_dim,
    )

    print("\n==== Start QKV Comparison ====\n")
    for layer_idx in range(num_layers):
        print(f"--- Layer {layer_idx} ---")

        # Extract QKV from SAT
        sat_qkv = extract_qkv_from_sat(sat_state_dict, layer_idx)

        # Extract QKV from CogView4
        cogview_qkv = extract_qkv_from_cogview(
            cogview_state_dict, layer_idx, num_heads, head_dim, hidden_dim
        )

        # Extract QKV from the Megatron->SAT conversion
        mega_qkv = extract_qkv_from_sat(mega_as_sat_state_dict, layer_idx)

        # Compare: SAT vs CogView4
        compare_qkv(sat_qkv, cogview_qkv, name1="SAT", name2="CogView4")

        # Compare: SAT vs Megatron
        compare_qkv(sat_qkv, mega_qkv, name1="SAT", name2="Megatron")

        # Compare: CogView4 vs Megatron (optional)
        compare_qkv(cogview_qkv, mega_qkv, name1="CogView4", name2="Megatron")

        print()

    print("=== Done ===")