@@ -728,13 +728,20 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
         if self.tok_embeddings:
             x = self.tok_embeddings(x)

+            # For Granite architectures
+            if self.config.embedding_multiplier:
+                x = x * self.config.embedding_multiplier
+
         for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

         if self.norm:
             x = self.norm(x)
         if self.output:
             x = self.output(x)
+        # For Granite architectures
+        if self.config.logits_scaling:
+            x = x / self.config.logits_scaling
         # print(f"output shape: {x.shape}")
         return x

@@ -746,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        # None for the Llama architecture, set for Granite architectures
+        self.residual_multiplier = (
+            config.residual_multiplier
+            if config.residual_multiplier is not None
+            else 1.0
+        )

     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
@@ -756,8 +769,8 @@ def forward(
     ) -> Tensor:
         h = x + self.attention(
             self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        ) * self.residual_multiplier
+        out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out


@@ -784,6 +797,7 @@ def __init__(self, config: TransformerArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.attention_scale = config.attention_multiplier
         self._register_load_state_dict_pre_hook(self.load_hook)

     def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -880,7 +894,16 @@ def forward(

         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            # This is None (default) for the Llama architecture and set for
+            # Granite architectures
+            scale=self.attention_scale,
+        )

         # -1 = self.dim
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
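
For context, the hooks above assume that TransformerArgs carries the four Granite scaling fields referenced via self.config. A minimal sketch of how those fields might be declared (the field names come from the diff; the Optional[...] = None defaults are an assumption chosen so that Llama-style configs, which never set them, keep the original behavior):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TransformerArgs:
    # ... existing Llama fields (dim, n_heads, head_dim, norm_eps, ...) elided ...
    # Granite-specific scalings; leaving them unset means "behave like Llama".
    embedding_multiplier: Optional[float] = None   # multiplies the token embeddings
    attention_multiplier: Optional[float] = None   # forwarded as the SDPA `scale`
    residual_multiplier: Optional[float] = None    # scales each residual branch output
    logits_scaling: Optional[float] = None         # the output logits are divided by this

With these defaults, F.scaled_dot_product_attention receives scale=None and falls back to its built-in 1/sqrt(head_dim) scaling, TransformerBlock falls back to a residual multiplier of 1.0, and the embedding_multiplier / logits_scaling branches are skipped, so existing Llama checkpoints behave exactly as before. Note also that in TransformerBlock.forward the multiplication binds to the sub-layer output, i.e. x + (attention(...) * residual_multiplier): only each branch's contribution is scaled, not the residual stream itself.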