Commit 21439e2

draft of transformers convert
1 parent b3cadb8 commit 21439e2

8 files changed: +414 additions, −136 deletions

scripts/convert_cogview3_to_diffusers.py

Lines changed: 219 additions & 7 deletions
@@ -3,23 +3,235 @@
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
+from diffusers import AutoencoderKL, CogVideoXDDIMScheduler
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
 
 from diffusers import (
     CogView3PlusTransformer2DModel,
     CogView3PlusPipeline,
 )
 
+
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
+    to_q_key = key.replace("query_key_value", "to_q")
+    to_k_key = key.replace("query_key_value", "to_k")
+    to_v_key = key.replace("query_key_value", "to_v")
+    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+    state_dict[to_q_key] = to_q
+    state_dict[to_k_key] = to_k
+    state_dict[to_v_key] = to_v
+    state_dict.pop(key)
+
+
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, weight_or_bias = key.split(".")[-2:]
+
+    if "query" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn.norm_q.{weight_or_bias}"
+    elif "key" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn.norm_k.{weight_or_bias}"
+
+    state_dict[new_key] = state_dict.pop(key)
+
+
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, _, weight_or_bias = key.split(".")[-3:]
+
+    weights_or_biases = state_dict[key].chunk(12, dim=0)
+    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+    state_dict[norm1_key] = norm1_weights_or_biases
+
+    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+    state_dict[norm2_key] = norm2_weights_or_biases
+
+    state_dict.pop(key)
+
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    state_dict.pop(key)
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+    state_dict = saved_dict
+    if "model" in saved_dict.keys():
+        state_dict = state_dict["model"]
+    if "module" in saved_dict.keys():
+        state_dict = state_dict["module"]
+    if "state_dict" in saved_dict.keys():
+        state_dict = state_dict["state_dict"]
+    return state_dict
+
+
 TRANSFORMER_KEYS_RENAME_DICT = {
     "transformer": "transformer_blocks",
-    "attention": "attn1",
-    "mlp": "ff.net",
+    "attention": "attn",
+    "mlp": "mlp.net",
     "dense_h_to_4h": "0.proj",
     "dense_4h_to_h": "2",
     ".layers": "",
     "dense": "to_out.0",
-    "patch_embed": "norm1.norm",
-    "post_attn1_layernorm": "norm2.norm",
-    "mixins.patch_embed": "patch_embed",
-    "mixins.final_layer.adaln": "norm_out",
+    "mixins.patch_embed": "image_patch_embed",
+    "mixins.adaln.adaln_modules": "adaln_module",
+    "time_embed": "time_embed",
+    "label_emb": "label_embed",
+    "mixins.final_layer.adaln": "final_layer.adaln",
     "mixins.final_layer.linear": "proj_out",
-}
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "query_key_value": reassign_query_key_value_inplace,
+}
+
+TOKENIZER_MAX_LENGTH = 224
+
+
+# VAE of CogView3Plus can be converted to diffusers without any changes
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+    vae = AutoencoderKL(
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",) * 4,
+        up_block_types=("UpDecoderBlock2D",) * 4,
+        block_out_channels=(128, 512, 1024, 1024),
+        layers_per_block=3,
+        act_fn="silu",
+        latent_channels=16,
+        norm_num_groups=32,
+        sample_size=1024,
+        scaling_factor=scaling_factor,
+        force_upcast=True,
+        use_quant_conv=False,
+        use_post_quant_conv=False,
+        mid_block_add_attention=False,
+    ).to(dtype=dtype)
+
+    # Convert the state dict to a format compatible with diffusers
+    converted_state_dict = convert_ldm_vae_checkpoint(original_state_dict, vae.config)
+
+    # Load the converted state dict into the VAE model
+    vae.load_state_dict(converted_state_dict, strict=False)
+
+    return vae
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+    state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    dtype: torch.dtype,
+):
+    PREFIX_KEY = "model.diffusion_model."
+
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+    transformer = CogView3PlusTransformer2DModel(
+        in_channels=16,
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+    ).to(dtype=dtype)
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[len(PREFIX_KEY):]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+    transformer.load_state_dict(original_state_dict, strict=True)
+
+    return transformer
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+    )
+    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
+    parser.add_argument(
+        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+    )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    parser.add_argument("--num_attention_heads", type=int, default=64, help="Number of attention heads")
+    parser.add_argument("--scaling_factor", type=float, default=0.18215, help="Scaling factor in the VAE")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    transformer = None
+    vae = None
+
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+    if args.transformer_ckpt_path is not None:
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            dtype,
+        )
+
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    scheduler = CogVideoXDDIMScheduler.from_config(
+        {
+            "beta_end": 0.012,
+            "beta_schedule": "scaled_linear",
+            "beta_start": 0.00085,
+            "clip_sample": False,
+            "num_train_timesteps": 1000,
+            "prediction_type": "v_prediction",
+            "rescale_betas_zero_snr": True,
+            "set_alpha_to_one": True,
+            "timestep_spacing": "trailing",
+        }
+    )
+
+    pipe = CogView3PlusPipeline(
+        tokenizer=tokenizer,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+    )
+
+    if args.fp16:
+        pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)
+
+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify a variant when the default is not fp32 and they want to run with the correct default
+    # (which is either fp16/bf16 here).
+
+    # This is necessary for users with insufficient memory, such as those using Colab and notebooks,
+    # as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
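For reference, here is a minimal, self-contained sketch (not part of the commit) of what the two passes in `convert_transformer` do to a single checkpoint key: the plain string renames from `TRANSFORMER_KEYS_RENAME_DICT`, followed by the fused-QKV split performed by `reassign_query_key_value_inplace`. The rename table below is a small subset and the tensor shape is illustrative only.

```python
from typing import Any, Dict

import torch

# Subset of TRANSFORMER_KEYS_RENAME_DICT, for illustration
RENAME = {
    "transformer": "transformer_blocks",
    "attention": "attn",
    ".layers": "",
}
PREFIX = "model.diffusion_model."

# Dummy checkpoint with one fused QKV projection (3 * hidden_size rows stacked along dim 0)
state_dict: Dict[str, Any] = {
    PREFIX + "transformer.layers.0.attention.query_key_value.weight": torch.randn(3 * 8, 8),
}

# Pass 1: strip the prefix and apply the string renames
for key in list(state_dict.keys()):
    new_key = key[len(PREFIX):]
    for old, new in RENAME.items():
        new_key = new_key.replace(old, new)
    state_dict[new_key] = state_dict.pop(key)

# Pass 2: split the fused QKV tensor into separate to_q / to_k / to_v entries
for key in list(state_dict.keys()):
    if "query_key_value" not in key:
        continue
    to_q, to_k, to_v = torch.chunk(state_dict.pop(key), chunks=3, dim=0)
    state_dict[key.replace("query_key_value", "to_q")] = to_q
    state_dict[key.replace("query_key_value", "to_k")] = to_k
    state_dict[key.replace("query_key_value", "to_v")] = to_v

print(sorted(state_dict.keys()))
# ['transformer_blocks.0.attn.to_k.weight',
#  'transformer_blocks.0.attn.to_q.weight',
#  'transformer_blocks.0.attn.to_v.weight']
```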

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
     _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
-    _import_structure["transformers.transformer_cogview3dplus"] = ["CogView3PlusTransformer2DModel"]
+    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]

src/diffusers/models/embeddings.py

Lines changed: 11 additions & 2 deletions
@@ -764,8 +764,17 @@ def forward(self, target_size: List[int]) -> torch.Tensor:
         return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
 
     def reinit(self):
-        # Initialize the positional embedding using a 2D sin-cos function
-        pos_embed_np = self.get_2d_sincos_pos_embed(self.hidden_size, self.max_height, self.max_width)
+        # Initialize the positional embedding using the updated 2D sin-cos function
+        grid_size = (self.max_height, self.max_width)
+        pos_embed_np = get_2d_sincos_pos_embed(
+            embed_dim=self.hidden_size,
+            grid_size=grid_size,
+        )
+
+        # Reshape the positional embedding to the desired shape
+        pos_embed_np = pos_embed_np.reshape(self.max_height, self.max_width, self.hidden_size)
+
+        # Copy the positional embedding data
         self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())
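As a reference for the `reinit` change above, here is a small sketch (not part of the commit) of the reshape-and-copy step, assuming `get_2d_sincos_pos_embed` returns a NumPy array of shape `(height * width, embed_dim)`; the array and the destination buffer below are random stand-ins for the real positional embedding and `self.image_pos_embedding`.

```python
import numpy as np
import torch

hidden_size, max_height, max_width = 16, 4, 6

# Stand-in for get_2d_sincos_pos_embed(embed_dim=hidden_size, grid_size=(max_height, max_width)),
# assumed to return a (H * W, D) NumPy array
pos_embed_np = np.random.randn(max_height * max_width, hidden_size).astype(np.float32)

# Flat (H * W, D) grid -> (H, W, D), matching the layout of the destination buffer
pos_embed_np = pos_embed_np.reshape(max_height, max_width, hidden_size)

# Stand-in for self.image_pos_embedding
image_pos_embedding = torch.zeros(max_height, max_width, hidden_size)
image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())

print(image_pos_embedding.shape)  # torch.Size([4, 6, 16])
```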