@@ -1019,7 +1019,7 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor


-def _get_rope_init_fn(self) -> Callable:
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
     if hasattr(self, "rope_init_fn"):
         # transformers<=5.0
         rope_init_fn = (
@@ -1030,8 +1030,9 @@ def _get_rope_init_fn(self) -> Callable:
         )
         return rope_init_fn

+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
     rope_init_fn = self.compute_default_rope_parameters
-    if self.rope_type != "default":
+    if rope_type != "default":
-        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]
     if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
         return patched__compute_dynamic_ntk_parameters
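
The hunk above is where the per-layer rope type gets resolved: in transformers>=5.0, self.rope_type can be a mapping keyed by layer type, while older versions keep a plain string. A minimal sketch of that selection logic, assuming illustrative layer-type names ("full_attention", "sliding_attention") that are not taken from the patch:

def pick_rope_type(rope_type, layer_type=None):
    # Plain string (transformers<5.0) or dict keyed by layer type (transformers>=5.0).
    return rope_type if layer_type is None else rope_type[layer_type]

assert pick_rope_type("longrope") == "longrope"
assert pick_rope_type(
    {"full_attention": "dynamic", "sliding_attention": "default"},
    layer_type="sliding_attention",
) == "default"
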
@@ -1101,17 +1102,27 @@ def wrapper(self, x, position_ids):

     """

-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1126,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq

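
longrope_frequency_update now reads and writes the frequency cache through a name prefix: no prefix for the single-cache layout of transformers<5.0, and "<layer_type>_" for the per-layer buffers of transformers>=5.0 ("<layer_type>_original_inv_freq", "<layer_type>_inv_freq"). A self-contained sketch of that getattr/setattr convention, with a hypothetical module and an illustrative layer-type name:

import torch

class _Demo(torch.nn.Module):
    # Hypothetical module holding both buffer layouts side by side.
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.ones(4))                    # transformers<5.0 layout
        self.register_buffer("sliding_attention_inv_freq", torch.ones(4))  # transformers>=5.0 layout

def store_inv_freq(module, inv_freq, layer_type=None):
    # Same prefix rule as in the patch: "" when layer_type is None, f"{layer_type}_" otherwise.
    prefix = "" if layer_type is None else f"{layer_type}_"
    setattr(module, f"{prefix}inv_freq", inv_freq)

m = _Demo()
store_inv_freq(m, torch.full((4,), 2.0))                                  # updates m.inv_freq
store_inv_freq(m, torch.full((4,), 3.0), layer_type="sliding_attention")  # updates m.sliding_attention_inv_freq
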
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings
@@ -1142,7 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = _get_rope_init_fn(self)
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)

         # This behaviour is difficult to translate.
         # The sequence always grows.
@@ -1171,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
             self.config, device, seq_len=seq_len
         )

+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.
@@ -1192,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)

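
Both frequency updates end with the same two lambdas and an operand list, the visible tail of a torch.cond call: the data-dependent "seq_len > ..." test has to stay a symbolic branch so it survives torch.export, where seq_len is unknown. A sketch of that pattern, assuming the truncated call is torch.cond(predicate, true_fn, false_fn, operands); the predicate and tensor shapes are illustrative:

import torch

def select_inv_freq(seq_len, max_seq_len, long_inv_freq, original_inv_freq):
    # A higher-order conditional instead of a Python `if` on a tensor value,
    # so the exported graph keeps both outcomes.
    return torch.cond(
        seq_len > max_seq_len,
        (lambda x, y: x.clone()),  # sequence grew: use the recomputed frequencies
        (lambda x, y: y.clone()),  # otherwise: restore the original cache
        [long_inv_freq, original_inv_freq],
    )
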
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
         if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
         elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper

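
Stripped of the cache refreshes, the wrapper above is a decorator that leaves the pre-5.0 call path untouched when no layer_type is given and forwards the extra keyword otherwise. A reduced sketch of just that control flow (the dynamic/longrope updates are elided):

from functools import wraps

def patched_dynamic_rope_update(rope_forward):
    @wraps(rope_forward)
    def wrapper(self, x, position_ids, layer_type=None):
        # ... dynamic / longrope frequency caches are refreshed here ...
        if layer_type is None:
            return rope_forward(self, x, position_ids)  # transformers<5.0 signature
        return rope_forward(self, x, position_ids, layer_type=layer_type)
    return wrapper
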
@@ -1296,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):
     # @torch.no_grad()
     # PATCHED: the decorator
     @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()

@@ -1313,8 +1354,8 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

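
For reference, the shape bookkeeping in the forward above, written out with small illustrative sizes (a standalone sketch, not the patched class): inv_freq has head_dim // 2 entries, position_ids is [batch, seq], the matmul plus transpose gives [batch, seq, head_dim // 2], and the concatenation doubles the last dimension before cos/sin:

import torch

batch, seq, half_dim = 2, 5, 4
inv_freq = 1.0 / (10000.0 ** (torch.arange(half_dim).float() / half_dim))
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)

inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1)  # [batch, half_dim, 1]
position_ids_expanded = position_ids[:, None, :].float()                  # [batch, 1, seq]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)       # [batch, seq, half_dim]
emb = torch.cat((freqs, freqs), dim=-1)                                    # [batch, seq, 2 * half_dim]
cos, sin = emb.cos(), emb.sin()
assert cos.shape == (batch, seq, 2 * half_dim)
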