diff --git a/torchchat/model.py b/torchchat/model.py
index 673b582d3..25b4ddcd7 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -281,6 +281,8 @@ class TransformerArgs:
     # Optional biases
     attention_bias: bool = False
     feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
         if config.stage_idx == config.n_stages - 1:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
         else:
             self.norm = None
             self.output = None
 
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])
 
     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
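
The change follows the standard PyTorch weight-tying pattern: the output projection shares the embedding matrix, and a `load_state_dict` pre-hook fills in the output weight when a checkpoint stores only the embeddings. A minimal, self-contained sketch of that behavior (the `TinyTiedLM` class and its key names are illustrative, not torchchat's):

```python
# Sketch of weight tying plus a load-time hook, independent of torchchat.
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 8, dim: int = 4, tie_word_embeddings: bool = True):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if tie_word_embeddings:
            # Share one Parameter object; updates to either name touch the same tensor.
            self.output.weight = self.tok_embeddings.weight
        self._register_load_state_dict_pre_hook(self._load_hook)

    def _load_hook(self, state_dict, prefix, *args):
        # Checkpoints for tied models may omit the output weight entirely;
        # alias it from the embedding weight before loading.
        out_key, emb_key = prefix + "output.weight", prefix + "tok_embeddings.weight"
        if out_key not in state_dict and emb_key in state_dict:
            state_dict[out_key] = state_dict[emb_key]

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


model = TinyTiedLM()
assert model.output.weight is model.tok_embeddings.weight  # single shared tensor

# A checkpoint that omits the tied output weight still loads under strict=True,
# because the pre-hook inserts the missing key before validation.
ckpt = {"tok_embeddings.weight": torch.randn(8, 4)}
model.load_state_dict(ckpt)
```

The hook in the PR hard-codes the `model.`-prefixed keys used by torchchat checkpoints; the sketch above derives them from `prefix` only to stay generic.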