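# Converts the original HunyuanVideo checkpoints into the diffusers format.
#
# Example invocation (script name and paths are illustrative; adjust them to wherever
# the original checkpoints were downloaded):
#
#   python convert_hunyuan_video_to_diffusers.py \
#       --transformer_ckpt_path /path/to/original/transformer.pt \
#       --vae_ckpt_path /path/to/original/vae.pt \
#       --output_path ./hunyuan-video-diffusers \
#       --dtype bf16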
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


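# The original final-layer adaLN modulation stores (shift, scale); diffusers'
# norm_out.linear expects (scale, shift), so the two halves are swapped here.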
def remap_norm_scale_shift_(key, state_dict):
    weight = state_dict.pop(key)
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


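# Remaps the text-input token refiner: fused self-attention QKV weights are split
# into separate to_q/to_k/to_v projections, and the remaining keys are renamed to
# the diffusers context_embedder / token_refiner naming scheme.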
def remap_txt_in_(key, state_dict):
    def rename_key(key):
        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
        new_key = new_key.replace("txt_in", "context_embedder")
        new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
        new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
        new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
        new_key = new_key.replace("mlp", "ff")
        return new_key

    if "self_attn_qkv" in key:
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
    else:
        state_dict[rename_key(key)] = state_dict.pop(key)


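# Splits the fused image-branch QKV projection of the double-stream blocks into
# separate to_q/to_k/to_v tensors.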
def remap_img_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


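# Splits the fused text-branch QKV projection into the added (context) q/k/v
# projections used by diffusers' joint attention.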
def remap_txt_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


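# Single-stream blocks fuse QKV and the MLP input projection into one `linear1`
# weight/bias; split them back out (hidden_size is hard-coded to 3072 here) and
# rename the remaining keys to the diffusers layout.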
def remap_single_transformer_blocks_(key, state_dict):
    hidden_size = 3072

    if "linear1.weight" in key:
        linear1_weight = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
        q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
        state_dict[f"{new_key}.attn.to_q.weight"] = q
        state_dict[f"{new_key}.attn.to_k.weight"] = k
        state_dict[f"{new_key}.attn.to_v.weight"] = v
        state_dict[f"{new_key}.proj_mlp.weight"] = mlp

    elif "linear1.bias" in key:
        linear1_bias = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
        state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
        state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
        state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
        state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

    else:
        new_key = key.replace("single_blocks", "single_transformer_blocks")
        new_key = new_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)


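# Simple substring renames applied to every transformer key before the special-case
# handlers below run.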
TRANSFORMER_KEYS_RENAME_DICT = {
    "img_in": "x_embedder",
    "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
    "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
    "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
    "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
    "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
    "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
    "double_blocks": "transformer_blocks",
    "img_attn_q_norm": "attn.norm_q",
    "img_attn_k_norm": "attn.norm_k",
    "img_attn_proj": "attn.to_out.0",
    "txt_attn_q_norm": "attn.norm_added_q",
    "txt_attn_k_norm": "attn.norm_added_k",
    "txt_attn_proj": "attn.to_add_out",
    "img_mod.linear": "norm1.linear",
    "img_norm1": "norm1.norm",
    "img_norm2": "norm2",
    "img_mlp": "ff",
    "txt_mod.linear": "norm1_context.linear",
    "txt_norm1": "norm1.norm",
    "txt_norm2": "norm2_context",
    "txt_mlp": "ff_context",
    "self_attn_proj": "attn.to_out.0",
    "modulation.linear": "norm.linear",
    "pre_norm": "norm.norm",
    "final_layer.norm_final": "norm_out.norm",
    "final_layer.linear": "proj_out",
    "fc1": "net.0.proj",
    "fc2": "net.2",
    "input_embedder": "proj_in",
}

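# Keys that need more than a rename (splitting fused tensors, reordering chunks)
# are dispatched to the in-place handlers defined above.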
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "txt_in": remap_txt_in_,
    "img_attn_qkv": remap_img_attn_qkv_,
    "txt_attn_qkv": remap_txt_attn_qkv_,
    "single_blocks": remap_single_transformer_blocks_,
    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

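# The VAE checkpoint needs no renames or special-case handlers; its keys are
# loaded as-is.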
VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


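# Renames a single key in place (the trailing underscore marks in-place mutation).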
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


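# Unwraps common checkpoint containers ("model", "module", "state_dict") to reach
# the raw state dict.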
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in state_dict:
        state_dict = state_dict["model"]
    if "module" in state_dict:
        state_dict = state_dict["module"]
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return state_dict


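# Loads the original transformer checkpoint, applies the rename table and the
# special-case handlers, and loads the result into an empty
# HunyuanVideoTransformer3DModel (assign=True materializes the meta tensors).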
def convert_transformer(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        transformer = HunyuanVideoTransformer3DModel()

    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


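# Same flow as convert_transformer, but for the VAE (currently a pass-through,
# since both VAE remap tables are empty).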
def convert_vae(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original Llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original Llama tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original CLIP checkpoint")
    parser.add_argument(
        "--save_pipeline",
        action="store_true",
        help="Whether to save a full HunyuanVideoPipeline instead of individual components",
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path where the converted model should be saved"
    )
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
        assert args.tokenizer_path is not None
        assert args.text_encoder_2_path is not None

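    # Convert whichever checkpoints were provided; without --save_pipeline each
    # converted component is saved on its own to --output_path.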
    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

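    # Assemble and save the full HunyuanVideoPipeline: the converted transformer and
    # VAE together with the pretrained text encoders, tokenizers, and a
    # FlowMatchEulerDiscreteScheduler (shift=7.0).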
    if args.save_pipeline:
        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
        text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
        tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        pipe = HunyuanVideoPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")