 import os
 import re
 import sys
+import glob
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional

 import torch
+import safetensors.torch
+import shutil

 # support running without installing as a package
 wd = Path(__file__).parent.parent
@@ -24,34 +27,34 @@ def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
     translated_state_dict = {}
     hf_weight_prefix = "vision_model."
     name_mapping = {
-        f"{hf_weight_prefix}embeddings.class_embedding": "model.encoder.cls_token_embedding.weight",
-        f"{hf_weight_prefix}embeddings.position_embedding.weight": "model.encoder.token_pos_embedding.positional_embedding",
-        f"{hf_weight_prefix}embeddings.patch_embedding.weight": "model.encoder.conv.weight",
-        f"{hf_weight_prefix}pre_layrnorm.weight": "model.encoder.ln_pre.weight",
-        f"{hf_weight_prefix}pre_layrnorm.bias": "model.encoder.ln_pre.bias",
-        f"{hf_weight_prefix}post_layernorm.weight": "model.encoder.ln_post.weight",
-        f"{hf_weight_prefix}post_layernorm.bias": "model.encoder.ln_post.bias",
+        f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight",
+        f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding",
+        f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight",
+        f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight",
+        f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias",
+        f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight",
+        f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias",
     }
     patterns = [
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
+            lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
+            lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
         ),
         (
             rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
-            lambda match: f"model.encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
+            lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
         ),
     ]
     for pattern, replacement in patterns:
@@ -82,18 +85,18 @@ def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:

 def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
     key_map = {
-        r"model.layers.([0-9]+).self_attn.q_proj.": r"model.decoder.layers.\1.attention.wq.",
-        r"model.layers.([0-9]+).self_attn.k_proj.": r"model.decoder.layers.\1.attention.wk.",
-        r"model.layers.([0-9]+).self_attn.v_proj.": r"model.decoder.layers.\1.attention.wv.",
-        r"model.layers.([0-9]+).self_attn.o_proj.": r"model.decoder.layers.\1.attention.wo.",
-        r"model.layers.([0-9]+).input_layernorm.": r"model.decoder.layers.\1.attention_norm.",
-        r"model.layers.([0-9]+).mlp.gate_proj.": r"model.decoder.layers.\1.feed_forward.w1.",
-        r"model.layers.([0-9]+).mlp.down_proj.": r"model.decoder.layers.\1.feed_forward.w2.",
-        r"model.layers.([0-9]+).mlp.up_proj.": r"model.decoder.layers.\1.feed_forward.w3.",
-        r"model.layers.([0-9]+).post_attention_layernorm.": r"model.decoder.layers.\1.ffn_norm.",
-        r"model.norm.": r"model.decoder.norm.",
+        r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
+        r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.",
+        r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.",
+        r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.",
+        r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.",
+        r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
+        r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.",
+        r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.",
+        r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.",
+        r"model.norm.": r"decoder.norm.",
         # r"model.embed_tokens.": r"tok_embeddings.", # load separately
-        r"lm_head.": r"model.decoder.output.",
+        r"lm_head.": r"decoder.output.",
     }
     new_state_dict = {}
     def get_new_key(old_key: str) -> str:
@@ -109,7 +112,7 @@ def get_new_key(old_key: str) -> str:
 def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
     new_state_dict = {}
     for old_key in hf_state_dict.keys():
-        new_key = "model.mm_projector." + old_key
+        new_key = "mm_projector." + old_key
         new_state_dict[new_key] = hf_state_dict[old_key]
     return new_state_dict

@@ -127,13 +130,65 @@ def split_checkpoint(llava_ckpt):
         return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
     language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
     remapped_state_dict = {
-        "model.tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
+        "tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
     }
     remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
     remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
     remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
     return remapped_state_dict

+
+@torch.inference_mode()
+def convert_llava_checkpoint(
+    *,
+    model_dir: Optional[Path] = None,
+) -> None:
+
+    """
+    Process safetensor files from a specific directory structure and save the remapped model.
+
+    Args:
+        model_dir (Optional[Path]): Base directory containing the model subdirectories.
+    """
+
+    def _get_llava_files_with_pattern(pattern):
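+        # Files are resolved from the Hugging Face hub cache layout:
+        # <model_dir>/models--llava-hf--llava-1.5-7b-hf/snapshots/<hash>/<pattern>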
+        pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}")
+        return glob.glob(pattern)
+
+    # Get all safetensor files in the model directory
+    safetensor_files = _get_llava_files_with_pattern("*.safetensors")
+
+    if not safetensor_files:
+        raise ValueError("No safetensor files found.")
+
+    merged_weights = {}
+
+    # Merge the safetensor shards into a single state dict
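+    # (In a sharded Hugging Face checkpoint each tensor lives in exactly one shard,
+    # so the concatenation branch below only triggers if a key repeats across files.)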
+    for file in safetensor_files:
+        # Load weights from the current file
+        part_weights = safetensors.torch.load_file(file)
+
+        # Iterate over each weight in the current file
+        for key, value in part_weights.items():
+            if key in merged_weights:
+                # If the key already exists, concatenate tensors
+                merged_weights[key] = torch.cat((merged_weights[key], value), dim=0)
+            else:
+                # If the key does not exist, add it to the dictionary
+                merged_weights[key] = value
+
+    # Remap the checkpoint and save it as model.pth
+    remapped_weights = remap_llava_checkpoint(merged_weights)
+    model_path = model_dir / "model.pth"
+    torch.save(remapped_weights, model_path)
+
+    # Copy the tokenizer
+    tokenizer_files = _get_llava_files_with_pattern("tokenizer.model")
+    assert len(tokenizer_files) == 1, "Expected exactly one tokenizer file, but got {}".format(tokenizer_files)
+
+    tokenizer_path = model_dir / "tokenizer.model"
+    shutil.copy(tokenizer_files[0], tokenizer_path)
+

 @torch.inference_mode()
 def convert_text_only_hf_checkpoint(
@@ -245,18 +300,18 @@ def permute(w, n_heads):


 @torch.inference_mode()
-def convert_text_only_hf_checkpoint(
+def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
     remove_bin_files: bool = False,
 ):
-    if model_name == "llava-1.5":
-        print("Converting LLaVA 1.5 checkpoint.")
-        print(os.listdir(model_dir))
-        exit(0)
+    print(model_name)
+    print("***********************")
+    if "llava" in model_name:
+        convert_llava_checkpoint(model_dir=model_dir)
     else:
-        convert_text_only_hf_checkpoint(model_dir, model_name, remove_bin_files)
+        convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files)


 if __name__ == "__main__":