This repository was archived by the owner on Sep 10, 2025. It is now read-only.
  
  
  
  
    
    
    
      
    
  
  
    
1 file changed: +10 -0 lines changed
Columns: original file line number | diff line number | diff line change
@@ -281,6 +281,8 @@ class TransformerArgs:
281281    # Optional biases 
282282    attention_bias : bool  =  False 
283283    feed_forward_bias : bool  =  False 
284+     # Whether or not to tie the input word embeddings to the output 
285+     tie_word_embeddings : bool  =  False 
284286
285287    def  __post_init__ (self ):
286288        if  self .n_local_heads  ==  - 1 :
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
632634        if  config .stage_idx  ==  config .n_stages  -  1 :
633635            self .norm  =  RMSNorm (config .dim , eps = config .norm_eps )
634636            self .output  =  nn .Linear (config .dim , config .vocab_size , bias = False )
637+             if  config .tie_word_embeddings :
638+                 self .output .weight  =  self .tok_embeddings .weight 
635639        else :
636640            self .norm  =  None 
637641            self .output  =  None 
638642
639643        self .max_batch_size  =  - 1 
640644        self .max_seq_length  =  - 1 
645+         self ._register_load_state_dict_pre_hook (self .load_hook )
646+ 
647+     def  load_hook (self , state_dict , prefix , * args ):
648+         """Handle tied embeddings at load time""" 
649+         if  self .config .tie_word_embeddings :
650+             state_dict .setdefault ("model.output.weight" , state_dict ["model.tok_embeddings.weight" ])
641651
642652    def  setup_caches (self , max_batch_size , max_seq_length , cache_lanes : int  =  1 ):
643653        if  (
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments