@@ -1019,6 +1019,26 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor
 
 
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
+    if hasattr(self, "rope_init_fn"):
+        # transformers<=5.0
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+        return rope_init_fn
+
+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
+    rope_init_fn = self.compute_default_rope_parameters
+    if rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]
+    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
+        return patched__compute_dynamic_ntk_parameters
+    return rope_init_fn
+
+
 def patched_dynamic_rope_update(rope_forward):
     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 
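
A minimal illustration of the substitution rule ``_get_rope_init_fn`` implements (not part
of the patch itself): whenever the selected init function is the stock
``_compute_dynamic_ntk_parameters``, the patched variant is returned instead, and every
other init function passes through untouched. ``patched_ntk`` below is a hypothetical
stand-in for ``patched__compute_dynamic_ntk_parameters``; the transformers names are real.

    import transformers.modeling_rope_utils as rope_utils

    def patched_ntk(config, device, seq_len=None):
        # stand-in for the patched dynamic-NTK initializer
        return rope_utils._compute_dynamic_ntk_parameters(config, device, seq_len=seq_len)

    def select(fn):
        # same identity test as _get_rope_init_fn: swap only the stock dynamic-NTK function
        return patched_ntk if fn is rope_utils._compute_dynamic_ntk_parameters else fn

    assert select(rope_utils._compute_dynamic_ntk_parameters) is patched_ntk
    assert select(rope_utils.ROPE_INIT_FUNCTIONS["default"]) is not patched_ntk
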
@@ -1082,22 +1102,27 @@ def wrapper(self, x, position_ids):
 
     """
 
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1112,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings
@@ -1128,12 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 
         # This behaviour is difficult to translate.
         # The sequence always grows.
@@ -1162,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
             self.config, device, seq_len=seq_len
         )
 
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.
@@ -1183,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
 
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
         if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
         elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
     return wrapper
 
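
A self-contained illustration (outside the patch) of the attribute naming the update
functions above rely on in the ``layer_type`` case: per-layer buffers are read as
``"<layer_type>_original_inv_freq"`` and written back as ``"<layer_type>_inv_freq"``
through ``getattr``/``setattr``, with an empty prefix when ``layer_type`` is None.
``ToyRope`` and the layer name "sliding_attention" are made up for the example.

    import torch

    class ToyRope(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # per-layer buffers, named "<layer_type>_..."
            self.register_buffer("sliding_attention_inv_freq", torch.ones(4))
            self.register_buffer("sliding_attention_original_inv_freq", torch.full((4,), 2.0))

    rope = ToyRope()
    layer_type = "sliding_attention"
    prefix = "" if layer_type is None else f"{layer_type}_"
    original_inv_freq = getattr(rope, f"{prefix}original_inv_freq")
    setattr(rope, f"{prefix}inv_freq", original_inv_freq.clone())  # same pattern as the patch
    print(getattr(rope, f"{prefix}inv_freq"))  # tensor([2., 2., 2., 2.])
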
@@ -1287,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):
     # @torch.no_grad()
     # PATCHED: the decorator
     @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()
 
@@ -1304,8 +1354,8 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
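
For reference, the cos/sin computation performed by the patched ``forward`` can be
reproduced with plain torch and a toy ``inv_freq`` (no transformers objects involved);
the base 10000.0 and the tensor sizes below are arbitrary example values.

    import torch

    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))  # 4 toy frequencies
    position_ids = torch.arange(6)[None, :]                            # batch of one sequence
    attention_scaling = 1.0

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos() * attention_scaling
    sin = emb.sin() * attention_scaling
    print(cos.shape, sin.shape)  # torch.Size([1, 6, 8]) torch.Size([1, 6, 8])
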