
Commit 310da29

Add a new conversion/comparison script for Megatron checkpoints
1 parent e239c3c commit 310da29

File tree: 1 file changed, +264 −0 lines
import torch
from collections import OrderedDict

from diffusers import CogView4Transformer2DModel
from tqdm import tqdm


def load_state_dict_sat(file_path):
    """Load the SAT state dictionary from a given file path."""
    # Typically, the stored SAT ckpt is in the format: {'module': {...}}
    ckpt = torch.load(file_path, map_location="cuda")
    return ckpt["module"]
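
# A more defensive variant of the load above (a sketch, not part of the
# original script): map to CPU and, on recent PyTorch versions, refuse
# arbitrary pickled objects via `weights_only`:
#   ckpt = torch.load(file_path, map_location="cpu", weights_only=True)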


def extract_qkv_from_sat(state_dict, layer_idx):
    """
    Extract QKV weights and biases from a SAT state_dict.
    Expects keys like:
        model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value
    """
    prefix = f"model.diffusion_model.transformer.layers.{layer_idx}.attention.query_key_value"
    w = state_dict[f"{prefix}.weight"].clone()
    b = state_dict[f"{prefix}.bias"].clone()
    return (w, b)
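
# Expected shapes for SAT's fused projection (per the docstring of
# extract_qkv_from_cogview below): w is [3*hidden_dim, hidden_dim],
# b is [3*hidden_dim].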


def load_state_dict_cogview(cogview_path):
    """
    Load the CogView4 model from diffusers and return its state_dict().
    NOTE: Adjust 'torch_dtype' and 'device_map' as appropriate.
    """
    cogview_model = CogView4Transformer2DModel.from_pretrained(
        cogview_path, torch_dtype=torch.bfloat16, device_map="auto"
    )
    return cogview_model.state_dict()
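
# Dtype note (an observation, not in the original script): the weights above
# are loaded in bfloat16, while the SAT/Megatron checkpoints may be stored in
# another dtype. torch.allclose requires matching dtypes, so cast both sides
# to a common dtype (e.g. `.float()`) before comparing if they differ.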


def extract_qkv_from_cogview(state_dict, layer_idx, num_heads, head_dim, hidden_dim):
    """
    Extract Q, K, V from a CogView4 checkpoint and reshape them into the same shape as SAT's fused QKV.
    For each layer i:
        Q prefix: transformer_blocks.{layer_idx}.attn1.to_q
        K prefix: transformer_blocks.{layer_idx}.attn1.to_k
        V prefix: transformer_blocks.{layer_idx}.attn1.to_v
    The final shapes must match SAT's [3*hidden_dim, hidden_dim] for the weight and [3*hidden_dim] for the bias.
    """
    q_prefix = f"transformer_blocks.{layer_idx}.attn1.to_q"
    k_prefix = f"transformer_blocks.{layer_idx}.attn1.to_k"
    v_prefix = f"transformer_blocks.{layer_idx}.attn1.to_v"

    # Extract
    q_weight = state_dict[f"{q_prefix}.weight"].clone()
    k_weight = state_dict[f"{k_prefix}.weight"].clone()
    v_weight = state_dict[f"{v_prefix}.weight"].clone()

    q_bias = state_dict[f"{q_prefix}.bias"].clone()
    k_bias = state_dict[f"{k_prefix}.bias"].clone()
    v_bias = state_dict[f"{v_prefix}.bias"].clone()

    # Reshape weights: [hidden_dim, hidden_dim] -> [num_heads, head_dim, hidden_dim],
    # then concat along the first dimension (which becomes 3*num_heads*head_dim).
    q_weight = q_weight.view(num_heads, head_dim, hidden_dim)
    k_weight = k_weight.view(num_heads, head_dim, hidden_dim)
    v_weight = v_weight.view(num_heads, head_dim, hidden_dim)

    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)  # shape: (3*num_heads, head_dim, hidden_dim)
    qkv_weight = qkv_weight.view(3 * num_heads * head_dim, hidden_dim)  # flatten

    # Reshape biases: [hidden_dim] -> [num_heads, head_dim]
    q_bias = q_bias.view(num_heads, head_dim)
    k_bias = k_bias.view(num_heads, head_dim)
    v_bias = v_bias.view(num_heads, head_dim)

    qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)  # (3*num_heads, head_dim)
    qkv_bias = qkv_bias.view(3 * num_heads * head_dim)

    return (qkv_weight, qkv_bias)
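
# Shape sanity check (a toy sketch, not part of the original script): the
# view/cat/view round trip above preserves row order, so it is equivalent to
# concatenating the 2D projection matrices directly:
#   nh, hd, h = 2, 3, 6
#   q, k, v = (torch.randn(nh * hd, h) for _ in range(3))
#   fused = torch.cat([t.view(nh, hd, h) for t in (q, k, v)], dim=0).view(3 * nh * hd, h)
#   assert torch.equal(fused, torch.cat([q, k, v], dim=0))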


def create_sat_state_dict_from_megatron(megatron_ckpt_dict, num_layers=48, num_heads=32, hidden_size=3072):
    """
    Convert a loaded Megatron checkpoint's 'model' dictionary into the same
    format used by SAT. This returns something like {'module': {...}} for
    easy comparison with SAT.

    Adapted from the original 'create_sat_state_dict' function; renamed here
    for clarity.
    """
    hidden_size_per_head = hidden_size // num_heads
    mega_weight = megatron_ckpt_dict["model"]
    sat_weight = {}

    # --- patch_embed ---
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.weight"] = \
        mega_weight["encoder_expand_linear.weight"].reshape(hidden_size, 64).clone()
    sat_weight["model.diffusion_model.mixins.patch_embed.proj.bias"] = \
        mega_weight["encoder_expand_linear.bias"].clone()

    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.weight"] = \
        mega_weight["text_projector.weight"].clone()
    sat_weight["model.diffusion_model.mixins.patch_embed.text_proj.bias"] = \
        mega_weight["text_projector.bias"].clone()

    # --- time embedding ---
    sat_weight["model.diffusion_model.time_embed.0.weight"] = \
        mega_weight["time_embedding.time_embed.0.weight"].clone()
    sat_weight["model.diffusion_model.time_embed.0.bias"] = \
        mega_weight["time_embedding.time_embed.0.bias"].clone()
    sat_weight["model.diffusion_model.time_embed.2.weight"] = \
        mega_weight["time_embedding.time_embed.2.weight"].clone()
    sat_weight["model.diffusion_model.time_embed.2.bias"] = \
        mega_weight["time_embedding.time_embed.2.bias"].clone()

    # --- label embedding ---
    sat_weight["model.diffusion_model.label_emb.0.0.weight"] = \
        mega_weight["label_embedding.label_embed.0.weight"].clone()
    sat_weight["model.diffusion_model.label_emb.0.0.bias"] = \
        mega_weight["label_embedding.label_embed.0.bias"].clone()
    sat_weight["model.diffusion_model.label_emb.0.2.weight"] = \
        mega_weight["label_embedding.label_embed.2.weight"].clone()
    sat_weight["model.diffusion_model.label_emb.0.2.bias"] = \
        mega_weight["label_embedding.label_embed.2.bias"].clone()

    # --- layers ---
    for i in tqdm(range(num_layers), desc="Converting Megatron->SAT"):
        # attention output
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.weight"] = \
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.dense.bias"] = \
            mega_weight[f"decoder.layers.{i}.self_attention.linear_proj.bias"].clone()

        # QKV
        qkv_weight = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.weight"].clone()
        qkv_bias = mega_weight[f"decoder.layers.{i}.self_attention.linear_qkv.bias"].clone()

        # Reshape QKV from the Megatron format into the SAT format:
        # qkv_weight: [3*hidden_size, hidden_size]
        #   -> [num_heads, 3, hidden_size_per_head, hidden_size] -> ...
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.weight"] = \
            qkv_weight.view(num_heads, 3, hidden_size_per_head, hidden_size) \
                      .permute(1, 0, 2, 3) \
                      .reshape(3 * hidden_size, hidden_size).clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.attention.query_key_value.bias"] = \
            qkv_bias.view(num_heads, 3, hidden_size_per_head) \
                    .permute(1, 0, 2) \
                    .reshape(3 * hidden_size) \
                    .clone()
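
        # Layout note (inferred from the view/permute above, stated as an
        # assumption): Megatron stores the fused projection interleaved per
        # head ([q_h0, k_h0, v_h0, q_h1, ...]); permuting the q/k/v axis in
        # front of the head axis de-interleaves it into SAT's [Q; K; V] order.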

        # MLP
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.weight"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_h_to_4h.bias"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc1.bias"].clone()

        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.weight"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.weight"].clone()
        sat_weight[f"model.diffusion_model.transformer.layers.{i}.mlp.dense_4h_to_h.bias"] = \
            mega_weight[f"decoder.layers.{i}.mlp.linear_fc2.bias"].clone()

        # AdaLN
        adaln_weight = mega_weight[f"decoder.layers.{i}.adaln.weight"].clone()
        adaln_bias = mega_weight[f"decoder.layers.{i}.adaln.bias"].clone()

        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.weight"] = adaln_weight.clone()
        sat_weight[f"model.diffusion_model.mixins.adaln.adaln_modules.{i}.1.bias"] = adaln_bias.clone()

    # --- final layers ---
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.weight"] = \
        mega_weight["adaln_final.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.adaln.1.bias"] = \
        mega_weight["adaln_final.bias"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.weight"] = \
        mega_weight["output_projector.weight"].clone()
    sat_weight["model.diffusion_model.mixins.final_layer.linear.bias"] = \
        mega_weight["output_projector.bias"].clone()

    return OrderedDict(sat_weight)
177+
def load_state_dict_megatron_and_convert_to_sat(megatron_ckpt_path, num_layers, num_heads, hidden_size):
178+
"""
179+
Load a Megatron checkpoint from <megatron_ckpt_path>, then convert it into
180+
an SAT-style OrderedDict for direct QKV comparison.
181+
182+
Typically, <megatron_ckpt_path> = ".../iter_0287500/mp_rank_00/model_optim_rng.pt"
183+
"""
184+
ckpt = torch.load(megatron_ckpt_path, map_location="cuda")
185+
# Convert to SAT
186+
sat_like_weight = create_sat_state_dict_from_megatron(
187+
ckpt, num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size
188+
)
189+
return sat_like_weight
190+


def compute_l2_difference(tensor1, tensor2):
    """Compute the L2 norm of the difference between two tensors."""
    return torch.norm(tensor1 - tensor2, p=2).item()


def compare_qkv(qkv1, qkv2, name1="Model1", name2="Model2", atol=1e-6):
    """
    Compare QKV from two different sources (each a tuple of (weight, bias)).
    Returns (weight_match, bias_match, weight_l2, bias_l2).
    """
    w1, b1 = qkv1
    w2, b2 = qkv2

    weight_match = torch.allclose(w1, w2, atol=atol)
    bias_match = torch.allclose(b1, b2, atol=atol)
    weight_l2_diff = compute_l2_difference(w1, w2)
    bias_l2_diff = compute_l2_difference(b1, b2)

    if not (weight_match and bias_match):
        print(f"[QKV Mismatch] {name1} vs {name2}")
        print(f"  Weight L2: {weight_l2_diff:.6f}, Bias L2: {bias_l2_diff:.6f}")
    else:
        # Everything matches within tolerance.
        print(f"[QKV Match] {name1} vs {name2} (Weight L2={weight_l2_diff:.6f}, Bias L2={bias_l2_diff:.6f})")

    return weight_match, bias_match, weight_l2_diff, bias_l2_diff
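
# Tolerance note (an observation, not in the original script): atol=1e-6 is
# very strict for bfloat16 tensors, whose machine epsilon is ~7.8e-3; loosen
# atol or compare in float32 when the checkpoints are stored in bf16.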


if __name__ == "__main__":
    num_layers = 28
    num_heads = 32
    hidden_dim = 4096
    head_dim = hidden_dim // num_heads

    sat_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_sat/0287500/mp_rank_00_model_states.pt"
    sat_state_dict = load_state_dict_sat(sat_ckpt_path)

    cogview_path = "/share/zyx/CogView4-6B-0128/transformer"  # directory containing the diffusers model index
    cogview_state_dict = load_state_dict_cogview(cogview_path)

    megatron_ckpt_path = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/pt_ema/iter_0287500/mp_rank_00/model_optim_rng.pt"
    mega_as_sat_state_dict = load_state_dict_megatron_and_convert_to_sat(
        megatron_ckpt_path,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_size=hidden_dim,
    )

    print("\n==== Start QKV Comparison ====\n")
    for layer_idx in range(num_layers):
        print(f"--- Layer {layer_idx} ---")

        # Extract QKV from SAT
        sat_qkv = extract_qkv_from_sat(sat_state_dict, layer_idx)

        # Extract QKV from CogView
        cogview_qkv = extract_qkv_from_cogview(
            cogview_state_dict, layer_idx, num_heads, head_dim, hidden_dim
        )

        # Extract QKV from Megatron->SAT
        mega_qkv = extract_qkv_from_sat(mega_as_sat_state_dict, layer_idx)

        # Compare: SAT vs CogView
        compare_qkv(sat_qkv, cogview_qkv, name1="SAT", name2="CogView4")

        # Compare: SAT vs Megatron
        compare_qkv(sat_qkv, mega_qkv, name1="SAT", name2="Megatron")

        # Compare: CogView vs Megatron (optional)
        compare_qkv(cogview_qkv, mega_qkv, name1="CogView4", name2="Megatron")

        print()

    print("=== Done ===")
