Commit 05d1352

Add RoPE scaling for Llama3.1
1 parent 46a803f commit 05d1352

File tree: 2 files changed, +60 -7 lines

exllamav2/config.py
exllamav2/model.py

exllamav2/config.py (22 additions, 7 deletions)
@@ -10,7 +10,9 @@
 T = TypeVar('T')
 no_default = object()

-def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
+def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
+
+    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]

     if isinstance(keys, str): keys = [keys]

@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
         if expected_type == int and isinstance(x, float) and x == int(x):
             x = int(x)

-        if isinstance(x, expected_type):
-            return cast(T, x)
-        else:
-            raise TypeError(f"Value for {key} is not of expected type {expected_type}")
+        for t in expected_types:
+            if isinstance(x, t):
+                return cast(T, x)
+        raise TypeError(f"Value for {key} is not of expected type {expected_type}")

     if default != no_default: return default
     raise ValueError(f"Missing any of the following keys: {keys}")
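Net effect of the two hunks above: read() now accepts either a single type or a list of acceptable types, returning the value on the first isinstance match. A minimal usage sketch (dict contents are illustrative, not from this commit):

    cfg = {"rope_theta": 500000.0}
    theta = read(cfg, [int, float], "rope_theta")            # float matches -> 500000.0
    w = read(cfg, int, ["sliding_window", "window"], None)   # keys absent -> default None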
@@ -105,7 +107,10 @@ class ExLlamaV2Config:
     attn_logit_softcapping: float | None
     sliding_window: int
     norm_head: int | None
-
+    l3_rope_factor: float | None
+    l3_rope_low_freq_factor: float | None
+    l3_rope_high_freq_factor: float | None
+    l3_rope_original_max_position_embeddings: int | None
     checkpoint_fused_mlp: bool
     checkpoint_offset_qzeros: bool


@@ -191,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
         # Vocab params

         self.bos_token_id = read(read_config, int, "bos_token_id", None)  # 1
-        self.eos_token_id = read(read_config, int, "eos_token_id", None)  # 2
+        self.eos_token_id = read(read_config, [int, list], "eos_token_id", None)  # 2
         self.pad_token_id = read(read_config, int, "pad_token_id", None)  # 0
         self.vocab_size = read(read_config, int, "vocab_size")

+        if isinstance(self.eos_token_id, list):
+            self.eos_token_id = self.eos_token_id[0]  # TODO: Figure out a way to maybe use all the EOS tokens somehow
+
         # Standard params

         self.initializer_range = read(read_config, float, ["initializer_range"])
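This fallback exists because Llama 3.1 checkpoints list several EOS tokens in config.json (e.g. "eos_token_id": [128001, 128008, 128009]); with the widened read() above, the loader accepts the list and keeps only the first entry. A sketch with illustrative values:

    read_config = {"eos_token_id": [128001, 128008, 128009]}
    eos = read(read_config, [int, list], "eos_token_id", None)  # -> the full list
    eos = eos[0] if isinstance(eos, list) else eos              # -> 128001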
@@ -287,6 +295,13 @@ def prepare(self, no_tensors: bool = False):
                 self.alt_rope_method = "su"
             # if scaling_type == "yarn":
             #     self.scale_alpha_value = factor
+            rope_type = rs.get("rope_type", None)
+            if rope_type == "llama3":
+                self.alt_rope_method = "llama3"
+                self.l3_rope_factor = rs["factor"]
+                self.l3_rope_low_freq_factor = rs["low_freq_factor"]
+                self.l3_rope_high_freq_factor = rs["high_freq_factor"]
+                self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

         # Checkpoint format (for GPTQ models)
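For reference, the rope_scaling block this branch consumes looks like the following in a Llama 3.1 config.json (the published 3.1 defaults, shown here as the dict rs for illustration):

    rs = {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    }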

exllamav2/model.py (38 additions, 0 deletions)
@@ -129,6 +129,31 @@ def get_scratch_slice(self, size_bytes):
         return scratch_slice


+    @staticmethod
+    def _apply_scaling(
+        freqs: torch.Tensor,
+        scale_factor: float = 8,
+        low_freq_factor: float = 1,
+        high_freq_factor: float = 4,
+        old_context_len: int = 8192,  # original llama3 length
+    ):
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+        return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)
+
+
     def prepare_sincos(self):

         device = _torch_device(self.device_idx)
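_apply_scaling implements Llama 3.1's piecewise frequency adjustment: components with wavelengths shorter than old_context_len / high_freq_factor pass through unscaled, wavelengths longer than old_context_len / low_freq_factor are divided by scale_factor, and the band in between interpolates smoothly. A vectorized equivalent (a sketch, not part of this commit; assumes high_freq_factor != low_freq_factor, as the assert above does):

    import math
    import torch

    def apply_scaling_vectorized(freqs, scale_factor = 8.0, low_freq_factor = 1.0,
                                 high_freq_factor = 4.0, old_context_len = 8192):
        wavelen = 2 * math.pi / freqs
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        # Linear ramp between the two wavelength cutoffs
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        scaled = (1 - smooth) * freqs / scale_factor + smooth * freqs
        out = torch.where(wavelen < high_freq_wavelen, freqs, scaled)              # short wavelengths: unscaled
        return torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, out)  # long wavelengths: fully scaled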
@@ -163,6 +188,19 @@ def prepare_sincos(self):

             inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))

+        # Llama 3.1
+
+        elif cfg.alt_rope_method == "llama3":
+
+            inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
+            inv_freq = self._apply_scaling(
+                inv_freq,
+                cfg.l3_rope_factor,
+                cfg.l3_rope_low_freq_factor,
+                cfg.l3_rope_high_freq_factor,
+                cfg.l3_rope_original_max_position_embeddings,
+            )
+
         # Regular

         else:
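After this branch, the scaled inv_freq flows into the same sin/cos table construction as the other RoPE variants in prepare_sincos. A minimal sketch of that downstream step (names such as cfg.max_seq_len are assumed, not shown in this diff):

    t = torch.arange(cfg.max_seq_len, device = device).float()
    freqs = torch.outer(t, inv_freq)       # (num_positions, head_dim // 2)
    sin, cos = freqs.sin(), freqs.cos()    # tables applied by the attention kernels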
