@@ -34,7 +34,7 @@
try:
    # TODO: remove this after we figure out where in torchtune an `evaluate` module
    # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval  # noqa
+    import lm_eval  # noqa
except Exception:
    pass

@@ -278,6 +278,11 @@ class TransformerArgs:
    # For pipeline parallel
    n_stages: int = 1
    stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
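
The new fields can also be supplied through the params-dict route used by `build_model` further down (`TransformerArgs.from_params(config_args)`). A minimal sketch, assuming the dict keys match the dataclass field names and the remaining fields keep their defaults (sizes are illustrative):

```python
params = {
    "dim": 256,
    "n_heads": 8,
    "vocab_size": 1000,
    "attention_bias": True,
    "feed_forward_bias": True,
    "tie_word_embeddings": True,
}
config = TransformerArgs.from_params(params)
assert config.attention_bias and config.feed_forward_bias and config.tie_word_embeddings
```
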
@@ -394,7 +399,7 @@ def from_name(cls, name: str):
        config = [
            config
            for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
        ]

        # We may have two or more configs matched (e.g., "7B" and
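
The old first clause uppercased only the right-hand side, so a known config key containing lowercase characters could never match. A standalone illustration of the before/after behavior (the key names here are made up, not the actual contents of `known_model_params`):

```python
known_model_params = ["Meta-Llama-3.1-8B", "7B"]  # hypothetical keys
name = "meta-llama-3.1-8b-instruct"

old = [c for c in known_model_params if c in name.upper() or c in name]
new = [c for c in known_model_params if c.upper() in name.upper() or c in name]

print(old)  # [] -- a mixed-case key is never a substring of the uppercased name
print(new)  # ['Meta-Llama-3.1-8B']
```
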
@@ -471,7 +476,7 @@ def build_model(self) -> nn.Module:
                modules[name] = module_class(TransformerArgs.from_params(config_args))
            else:
                modules[name] = module_class(**config_args)
-
+
        # Temporary add extra params to the DeepFusionModel.
        # TODO: Remove it once we can make fusion model configurable in model_param.
        if recipe.fusion_class == DeepFusionModel:
@@ -629,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
        if config.stage_idx == config.n_stages - 1:
            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
        else:
            self.norm = None
            self.output = None

        self.max_batch_size = -1
        self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])

    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        if (
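
A self-contained sketch of the two pieces added above (the module and key names below are simplified stand-ins, not the Transformer's own): assigning the embedding's `Parameter` to the output projection makes both modules share one tensor, and `setdefault` lets a checkpoint that ships only the embedding weight still populate the output key before `load_state_dict` runs.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(1000, 64)
head = nn.Linear(64, 1000, bias=False)
head.weight = emb.weight  # tied: one Parameter referenced by both modules
assert head.weight.data_ptr() == emb.weight.data_ptr()

# A checkpoint without an output-head entry gets one pointing at the embedding.
state_dict = {"tok_embeddings.weight": torch.randn(1000, 64)}
state_dict.setdefault("output.weight", state_dict["tok_embeddings.weight"])
assert state_dict["output.weight"] is state_dict["tok_embeddings.weight"]
```
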
@@ -730,16 +743,16 @@ def __init__(self, config: TransformerArgs):

        # key, query, value projections for all heads, but in a batch
        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
        self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
        )
        self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
        )

-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
        self.kv_cache = None

        self.n_heads = config.n_heads
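
Quick shape check for the biased projections (illustrative sizes; unspecified `TransformerArgs` fields are assumed to keep their defaults): with grouped-query attention, `wq` projects to `n_heads * head_dim` while `wk`/`wv` project to `n_local_heads * head_dim`, and each projection now carries a bias of matching output width when `attention_bias` is set.

```python
import torch

config = TransformerArgs(dim=256, n_heads=8, n_local_heads=4, attention_bias=True)
attn = Attention(config)
x = torch.randn(2, 16, config.dim)

assert attn.wq(x).shape == (2, 16, config.n_heads * config.head_dim)
assert attn.wk(x).shape == (2, 16, config.n_local_heads * config.head_dim)
assert attn.wo.bias.shape == (config.dim,)
```
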
@@ -766,14 +779,16 @@ def load_hook(self, state_dict, prefix, *args):
        # wv = state_dict.pop(prefix + "wv.weight")
        # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-        if prefix + "wqkv.weight" in state_dict:
-            wqkv = state_dict.pop(prefix + "wqkv.weight")
-            q_size = self.n_heads * self.head_dim
-            kv_size = self.n_local_heads * self.head_dim
-            wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
-            state_dict[prefix + "wq.weight"] = wq
-            state_dict[prefix + "wk.weight"] = wk
-            state_dict[prefix + "wv.weight"] = wv
+        for tensor_suffix in ["weight", "bias"]:
+            wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
+            if wqkv_key in state_dict:
+                wqkv = state_dict.pop(wqkv_key)
+                q_size = self.n_heads * self.head_dim
+                kv_size = self.n_local_heads * self.head_dim
+                wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
+                state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
+                state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
+                state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

        return

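
A standalone illustration of the split the hook now performs for both tensors (shapes are illustrative): `torch.split` along `dim=0` treats the 2-D fused weight and the 1-D fused bias the same way, which is why a single loop over `["weight", "bias"]` also covers checkpoints saved with `attention_bias` enabled.

```python
import torch

dim, n_heads, n_local_heads, head_dim = 256, 8, 4, 32
q_size = n_heads * head_dim         # 256
kv_size = n_local_heads * head_dim  # 128

wqkv_weight = torch.randn(q_size + 2 * kv_size, dim)
wqkv_bias = torch.randn(q_size + 2 * kv_size)

wq, wk, wv = torch.split(wqkv_weight, (q_size, kv_size, kv_size), dim=0)
bq, bk, bv = torch.split(wqkv_bias, (q_size, kv_size, kv_size), dim=0)

assert wq.shape == (q_size, dim) and wk.shape == wv.shape == (kv_size, dim)
assert bq.shape == (q_size,) and bk.shape == bv.shape == (kv_size,)
```
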
@@ -852,9 +867,9 @@ def forward(
class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)

    def distribute(self, device_mesh: DeviceMesh):
        parallelize_module(self.w1, device_mesh, ColwiseParallel())
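
Same pattern for the MLP (illustrative sizes; remaining fields are assumed to keep their defaults, with `hidden_dim` derived in `__post_init__` when left unset): with `feed_forward_bias` on, each of the three projections gains a bias sized to its output dimension.

```python
config = TransformerArgs(dim=256, n_heads=8, feed_forward_bias=True)
ffn = FeedForward(config)

assert ffn.w1.bias.shape == (config.hidden_dim,)
assert ffn.w3.bias.shape == (config.hidden_dim,)
assert ffn.w2.bias.shape == (config.dim,)
```
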