@@ -287,6 +287,11 @@ class TransformerArgs:
     feed_forward_bias: bool = False
     # Whether or not to tie the input word embeddings to the output
     tie_word_embeddings: bool = False
+    # Granite architecture multipliers
+    embedding_multiplier: Optional[float] = None
+    attention_multiplier: Optional[float] = None
+    residual_multiplier: Optional[float] = None
+    logits_scaling: Optional[float] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -723,13 +728,20 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
         if self.tok_embeddings:
             x = self.tok_embeddings(x)
 
+        # For Granite architectures
+        if self.config.embedding_multiplier:
+            x = x * self.config.embedding_multiplier
+
         for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
 
         if self.norm:
            x = self.norm(x)
         if self.output:
             x = self.output(x)
+        # For Granite architectures
+        if self.config.logits_scaling:
+            x = x / self.config.logits_scaling
         # print(f"output shape: {x.shape}")
         return x
 
@@ -741,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        # None for Llama architecture, set for Granite architectures
+        self.residual_multiplier = (
+            config.residual_multiplier
+            if config.residual_multiplier is not None
+            else 1.0
+        )
 
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
@@ -751,8 +769,8 @@ def forward(
     ) -> Tensor:
         h = x + self.attention(
             self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        ) * self.residual_multiplier
+        out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out
 
 
@@ -779,6 +797,7 @@ def __init__(self, config: TransformerArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.attention_scale = config.attention_multiplier
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -875,7 +894,16 @@ def forward(
 
         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            # This is None (the default) for the Llama architecture and set for
+            # Granite architectures
+            scale=self.attention_scale,
+        )
 
         # -1 = self.dim
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
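Taken together, the hunks above thread four optional Granite scalars through the model: `embedding_multiplier` scales the token embeddings, `attention_multiplier` is forwarded to `F.scaled_dot_product_attention` as its `scale` argument, `residual_multiplier` scales each residual branch inside a transformer block, and `logits_scaling` divides the final logits. The sketch below is a minimal, self-contained illustration of where each scalar acts, not the torchchat classes themselves; the `GraniteScales` container, `toy_forward`, `emb`, and `out_proj` names are hypothetical, while the four field names come from the diff. Leaving a field as None reproduces the Llama behaviour (in particular, `scale=None` makes SDPA fall back to 1/sqrt of the query's last dimension); it assumes a PyTorch version that accepts the `scale` keyword.

# Minimal sketch (not the actual torchchat code) of how the Granite multipliers
# from this diff enter a decoder forward pass.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F


@dataclass
class GraniteScales:  # hypothetical container mirroring the TransformerArgs fields
    embedding_multiplier: Optional[float] = None   # scales token embeddings
    attention_multiplier: Optional[float] = None   # forwarded as the SDPA `scale`
    residual_multiplier: Optional[float] = None    # scales each residual branch
    logits_scaling: Optional[float] = None         # divides the output logits


def toy_forward(
    tokens: torch.Tensor,              # (batch, seq) token ids
    emb: torch.nn.Embedding,
    out_proj: torch.nn.Linear,
    scales: GraniteScales,
) -> torch.Tensor:
    x = emb(tokens)
    if scales.embedding_multiplier:    # None for Llama -> no-op
        x = x * scales.embedding_multiplier

    # Single-head self-attention stand-in; scale=None falls back to
    # 1/sqrt(last dim of the query), i.e. the Llama default.
    attn = F.scaled_dot_product_attention(x, x, x, scale=scales.attention_multiplier)

    residual_multiplier = scales.residual_multiplier or 1.0
    x = x + attn * residual_multiplier  # the branch is scaled, not the skip connection

    logits = out_proj(x)
    if scales.logits_scaling:
        logits = logits / scales.logits_scaling  # applied after the output projection
    return logits

# Usage: Llama-style models leave every field as None; Granite checkpoints set all
# four from the values in their config.json.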