@@ -15,6 +15,8 @@ def get_rope_params_su(
15
15
):
16
16
head_dim = cfg .head_dim
17
17
base = cfg .rotary_embedding_base
18
+ if cfg .scale_alpha_value and cfg .scale_alpha_value != 1.0 :
19
+ base *= cfg .scale_alpha_value ** (cfg .head_dim / (cfg .head_dim - 2 ))
18
20
19
21
a = cfg .max_seq_len
20
22
b = cfg .original_max_seq_len
@@ -28,7 +30,6 @@ def get_rope_params_su(
28
30
inv_freq = 1.0 / (ext_factors * base ** (torch .arange (0 , head_dim , 2 , device = device ).float () / head_dim ))
29
31
return inv_freq , scaling_factor
30
32
31
-
32
33
# Llama 3.1
33
34
34
35
def get_rope_params_llama3 (
@@ -37,6 +38,8 @@ def get_rope_params_llama3(
37
38
):
38
39
head_dim = cfg .head_dim
39
40
base = cfg .rotary_embedding_base
41
+ if cfg .scale_alpha_value and cfg .scale_alpha_value != 1.0 :
42
+ base *= cfg .scale_alpha_value ** (cfg .head_dim / (cfg .head_dim - 2 ))
40
43
41
44
def apply_scaling (
42
45
freqs : torch .Tensor ,
@@ -80,6 +83,9 @@ def get_rope_params_yarn(
80
83
):
81
84
head_dim = cfg .head_dim
82
85
base = cfg .rotary_embedding_base
86
+ if cfg .scale_alpha_value and cfg .scale_alpha_value != 1.0 :
87
+ base *= cfg .scale_alpha_value ** (cfg .head_dim / (cfg .head_dim - 2 ))
88
+
83
89
yarn_max_position_embeddings = cfg .max_seq_len
84
90
85
91
# Only activate if longer than original ctx
@@ -146,6 +152,8 @@ def get_rope_params_default(
146
152
):
147
153
head_dim = cfg .head_dim
148
154
base = cfg .rotary_embedding_base
155
+ if cfg .scale_alpha_value and cfg .scale_alpha_value != 1.0 :
156
+ base *= cfg .scale_alpha_value ** (cfg .head_dim / (cfg .head_dim - 2 ))
149
157
150
158
inv_freq = 1.0 / (base ** (torch .arange (0 , head_dim , 2 , device = device ).float () / head_dim ))
151
159
return inv_freq , 1.0
0 commit comments