Skip to content

Commit d815f5f

Browse files
committed
Fix RoPE alpha after refactor in #4d25874
1 parent b2dd5a7 commit d815f5f

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

exllamav2/device.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,13 @@ def prepare_sincos(self):
123123
self.cos = self.sin
124124
return
125125

126-
base = cfg.rotary_embedding_base
127-
alpha = cfg.scale_alpha_value or 1.0
128-
scale = cfg.scale_pos_emb or 1.0
129-
130-
# Alpha scaling for any rope_scaling type
131-
132-
if alpha != 1.0: base *= alpha ** (cfg.head_dim / (cfg.head_dim - 2))
133-
134126
# RoPE params
135127

136128
inv_freq, scaling_factor = rope.get_rope_params(device, cfg)
137129

138130
# Common
139131

132+
scale = cfg.scale_pos_emb or 1.0
140133
t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
141134
if scale != 1.0: t /= scale
142135

exllamav2/rope.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def get_rope_params_su(
1515
):
1616
head_dim = cfg.head_dim
1717
base = cfg.rotary_embedding_base
18+
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
19+
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
1820

1921
a = cfg.max_seq_len
2022
b = cfg.original_max_seq_len
@@ -28,7 +30,6 @@ def get_rope_params_su(
2830
inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
2931
return inv_freq, scaling_factor
3032

31-
3233
# Llama 3.1
3334

3435
def get_rope_params_llama3(
@@ -37,6 +38,8 @@ def get_rope_params_llama3(
3738
):
3839
head_dim = cfg.head_dim
3940
base = cfg.rotary_embedding_base
41+
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
42+
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
4043

4144
def apply_scaling(
4245
freqs: torch.Tensor,
@@ -80,6 +83,9 @@ def get_rope_params_yarn(
8083
):
8184
head_dim = cfg.head_dim
8285
base = cfg.rotary_embedding_base
86+
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
87+
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
88+
8389
yarn_max_position_embeddings = cfg.max_seq_len
8490

8591
# Only activate if longer than original ctx
@@ -146,6 +152,8 @@ def get_rope_params_default(
146152
):
147153
head_dim = cfg.head_dim
148154
base = cfg.rotary_embedding_base
155+
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
156+
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
149157

150158
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
151159
return inv_freq, 1.0

0 commit comments

Comments (0)