This repository was archived by the owner on Sep 10, 2025. It is now read-only.
File tree: 1 file changed, +10 −0 lines changed.
lines changed Original file line number Diff line number Diff line change @@ -281,6 +281,8 @@ class TransformerArgs:
281281 # Optional biases
282282 attention_bias : bool = False
283283 feed_forward_bias : bool = False
284+ # Whether or not to tie the input word embeddings to the output
285+ tie_word_embeddings : bool = False
284286
285287 def __post_init__ (self ):
286288 if self .n_local_heads == - 1 :
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
632634 if config .stage_idx == config .n_stages - 1 :
633635 self .norm = RMSNorm (config .dim , eps = config .norm_eps )
634636 self .output = nn .Linear (config .dim , config .vocab_size , bias = False )
637+ if config .tie_word_embeddings :
638+ self .output .weight = self .tok_embeddings .weight
635639 else :
636640 self .norm = None
637641 self .output = None
638642
639643 self .max_batch_size = - 1
640644 self .max_seq_length = - 1
645+ self ._register_load_state_dict_pre_hook (self .load_hook )
646+
647+ def load_hook (self , state_dict , prefix , * args ):
648+ """Handle tied embeddings at load time"""
649+ if self .config .tie_word_embeddings :
650+ state_dict .setdefault ("model.output.weight" , state_dict ["model.tok_embeddings.weight" ])
641651
642652 def setup_caches (self , max_batch_size , max_seq_length , cache_lanes : int = 1 ):
643653 if (
You can’t perform that action at this time.
0 commit comments