@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
             # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
             res = "minerva-7b"
+        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
+            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
+            res = "hunyuan"
 
         if res is None:
             logger.warning("\n")
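
The hash registered in this hunk is the checksum that `get_vocab_base_pre()` derives from tokenizing a fixed test string. A minimal sketch of that idea, assuming a local `transformers` install; the shortened `chktxt` and the model id are illustrative stand-ins, since the converter uses a much longer test string defined in the script:

```python
# Sketch: roughly how a pre-tokenizer checksum like the one above is obtained.
# The short chktxt here is a placeholder for the long test string in convert_hf_to_gguf.py.
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "\n \n\n \t🚀 (normal) Hello 3.14 !!!!!!"  # placeholder test string
tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # with the real chktxt this should print 7e57df22...f1c664
```
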
@@ -6535,6 +6538,155 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
+
+@ModelBase.register("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            if len(merged) == 2:  # todo this is an assert in Qwen, why?
+                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
+
+        moe_intermediate_size = hparams["moe_intermediate_size"]
+        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
+
+        moe_topk = hparams["moe_topk"]
+        assert all(topk == moe_topk[0] for topk in moe_topk)
+        self.gguf_writer.add_expert_used_count(moe_topk[0])
+
+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
+            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
 
 ###### CONVERSION LOGIC ######
 
 
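
The NTK-aware alpha scaling in `set_gguf_parameters()` above collapses to a single adjusted frequency base. A standalone sketch of that arithmetic with the Hunyuan-A13B values hard-coded (the converter reads `base`, `alpha`, and the head dimension from config.json; the per-head dimension works out to 128 according to the comment in the diff):

```python
# Sketch of the NTK-aware alpha RoPE scaling performed above, using Hunyuan-A13B's values.
base = 10000.0   # rope_theta from config.json
alpha = 1000     # rope_scaling.alpha from config.json
dim = 128        # hidden_size // num_attention_heads

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(scaled_base)  # ~11158839.9251 -- the value written as rope_freq_base
```
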
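The expert-merging branch of `modify_tensors()` above waits until all per-expert weights of a layer have arrived and then stacks them into one 3-D tensor per projection. A toy sketch of that reshaping; the expert count and shapes are invented for illustration, and "ffn_gate_exps" names the style of merged GGUF tensor the mapping resolves to:

```python
# Toy sketch of the expert merge performed in modify_tensors() above.
# Shapes and n_experts are made up; the real converter stacks the per-expert
# down_proj/gate_proj/up_proj weights it has collected in self._experts.
import torch

n_experts, n_ff, n_embd = 4, 3072, 1024
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]  # e.g. one gate_proj per expert

merged = torch.stack(per_expert, dim=0)
assert merged.shape == (n_experts, n_ff, n_embd)
# The result is stored as a single "ffn_gate_exps"-style 3-D tensor instead of
# n_experts separate 2-D tensors, which is the layout map_tensor_name() expects.
```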