Commit d471d44

Gemma3 local RoPE fixes

Parent: a03db45

4 files changed (+25, -7 lines)

exllamav2/architecture.py

Lines changed: 4 additions & 0 deletions
@@ -169,6 +169,8 @@ class Params:
     swa = False
     alternating_swa = False
     sliding_rope_theta = None
+    sliding_rope_scale = None
+    pos_id_index = 0
 
     # Model only works with eager attention
     eager_attn_only = False
@@ -508,6 +510,7 @@ class Params:
             self.lm.alternating_swa = True
             self.lm.residual_stream_fp32 = True
             self.lm.sliding_rope_theta = 10000
+            self.lm.sliding_rope_scale = 1
             self.lm.default_vocab_size = 262208
             self.lm.default_rms_norm_eps = 1e-06
             self.lm.default_head_dim = 256
@@ -516,6 +519,7 @@ class Params:
             self.lm.default_use_qk_norm = True
             self.lm.default_sliding_window_pattern = 6
             self.lm.default_rope_theta = 1e6
+            self.lm.pos_id_index = 1
 
             self.vt_prefix = "vision_tower.vision_model."
             self.vt.keys.update({
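For orientation, here is a hedged, self-contained summary of what the two new architecture parameters express for Gemma3's text model. The dict and its key names are illustrative only (they are not part of the codebase); the values come from this diff and the config.py/device.py hunks below.

```python
# Illustrative only: how the new parameters pair up downstream.
gemma3_text_rope = {
    # full-attention layers (rope table index 0): theta from default_rope_theta,
    # linear scale taken from config.json rope_scaling if present
    "global": {"theta": 1e6, "scale": None},
    # sliding-window layers (rope table index 1): the new sliding_rope_scale
    # pins the local rope to an unscaled theta of 10000
    "local": {"theta": 10000, "scale": 1},
    # pos_id_index shifts every RoPE position id; 1 for Gemma3, default 0
    "pos_id_index": 1,
}
```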

exllamav2/config.py

Lines changed: 7 additions & 2 deletions
@@ -100,6 +100,8 @@ class ExLlamaV2Config:
     vocab_size: int
     rotary_embedding_base: float
     rotary_embedding_base_alt: float | None
+    pos_id_index: int
+    scale_pos_emb_alt: float | None
     scale_long_factor: list[float] | None
     scale_short_factor: list[float] | None
     alt_rope_method: str | None
@@ -358,6 +360,8 @@ def prepare(self, no_tensors: bool = False):
         )
 
         self.rotary_embedding_base_alt = self.arch.lm.sliding_rope_theta
+        self.scale_pos_emb_alt = self.arch.lm.sliding_rope_scale
+        self.pos_id_index = self.arch.lm.pos_id_index
 
         self.max_seq_len = read(
             read_config,
@@ -373,11 +377,12 @@ def prepare(self, no_tensors: bool = False):
 
         self.partial_rotary_factor = read(read_config, float, "partial_rotary_factor", 1.0)
 
-        rs = read(read_config, dict, "rope_scaling", None)
+        rs = read(read_config, dict, ["rope_scaling", "text_config->rope_scaling"], None)
         if rs:
             scaling_type = rs.get("type", None)
             rope_type = rs.get("rope_type", None)
             assert not (scaling_type and rope_type), "rope_scaling key has both `type` and `rope_type` subkeys"
+            if not scaling_type: scaling_type = rope_type
             if scaling_type == "linear":
                 assert "factor" in rs, "'factor' missing from 'rope_scaling' config"
                 self.scale_pos_emb = rs.get("factor", 1.0)
@@ -394,7 +399,7 @@ def prepare(self, no_tensors: bool = False):
                 self.alt_rope_method = "yarn"
                 self.yarn_rope_factor = rs["factor"]
                 self.yarn_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
-            if rope_type == "llama3":
+            if scaling_type == "llama3":
                 self.alt_rope_method = "llama3"
                 self.l3_rope_factor = rs["factor"]
                 self.l3_rope_low_freq_factor = rs["low_freq_factor"]
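As a rough illustration of the parsing change above: `rope_scaling` may now live either at the top level of config.json or nested under `text_config`, and its discriminator may be spelled `type` or `rope_type`. The standalone sketch below mirrors that fallback with an invented helper name; in exllamav2 the nested `text_config->rope_scaling` path is handled by the repo's `read()` helper, not like this.

```python
# Hypothetical helper, not exllamav2 code: mirrors the fallback logic added above.
def read_rope_scaling(read_config: dict) -> dict | None:
    rs = read_config.get("rope_scaling") \
         or read_config.get("text_config", {}).get("rope_scaling")
    if not rs:
        return None
    scaling_type = rs.get("type", None)
    rope_type = rs.get("rope_type", None)
    assert not (scaling_type and rope_type), "rope_scaling key has both `type` and `rope_type` subkeys"
    if not scaling_type: scaling_type = rope_type
    return dict(rs, type = scaling_type)

# e.g. a Gemma3-style multimodal config nests the key under text_config:
read_rope_scaling({"text_config": {"rope_scaling": {"rope_type": "linear", "factor": 8.0}}})
# -> {"rope_type": "linear", "factor": 8.0, "type": "linear"}
```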

exllamav2/device.py

Lines changed: 8 additions & 2 deletions
@@ -42,6 +42,8 @@ class ExLlamaV2DeviceContext:
 
     sin: list[torch.Tensor] | None
     cos: list[torch.Tensor] | None
+    local_sin: list[torch.Tensor] | None
+    local_cos: list[torch.Tensor] | None
 
     scratch: torch.Tensor | None
 
@@ -119,13 +121,15 @@ def prepare_sincos(self):
         cfg = self.model.config
 
         thetas = [cfg.rotary_embedding_base]
+        scales = [cfg.scale_pos_emb]
         if cfg.rotary_embedding_base_alt:
             thetas.append(cfg.rotary_embedding_base_alt)
+            scales.append(cfg.scale_pos_emb_alt)
 
         self.sin = []
         self.cos = []
 
-        for theta in thetas:
+        for theta, lscale in zip(thetas, scales):
 
             if self.archparams.rope_style == RopeStyle.NONE:
                 sin = torch.zeros((1,), device = device, dtype = torch.half)
@@ -140,8 +144,10 @@ def prepare_sincos(self):
 
             # Common
 
-            scale = cfg.scale_pos_emb or 1.0
+            scale = lscale or 1.0
             t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
+            if cfg.pos_id_index != 0:
+                t += cfg.pos_id_index
             if scale != 1.0: t /= scale
 
             freqs = torch.einsum("i,j->ij", t, inv_freq)
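To see the per-variant scale and the position-id offset in isolation, here is a self-contained sketch. The function and argument names are made up; the real logic lives in `ExLlamaV2DeviceContext.prepare_sincos` and also deals with rope styles, partial rotary factors and the other scaling methods. The numeric values in the usage lines are illustrative.

```python
import torch

# Hypothetical standalone version of the sin/cos table construction:
# one table per RoPE variant, each with its own theta and linear scale,
# and position ids optionally shifted by pos_id_index before scaling.
def make_sincos(theta: float, scale: float | None, pos_id_index: int,
                head_dim: int, max_seq_len: int, device = "cpu"):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device = device, dtype = torch.float32) / head_dim))
    t = torch.arange(max_seq_len, device = device, dtype = torch.float32)
    if pos_id_index != 0:
        t += pos_id_index          # e.g. Gemma3 starts RoPE position ids at 1
    scale = scale or 1.0
    if scale != 1.0:
        t /= scale                 # linear RoPE scaling, applied per variant
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return freqs.sin().half(), freqs.cos().half()

# Illustrative values: a scaled global table and an unscaled local (sliding-window) table.
sin_glb, cos_glb = make_sincos(theta = 1e6,   scale = 8.0, pos_id_index = 1, head_dim = 256, max_seq_len = 4096)
sin_loc, cos_loc = make_sincos(theta = 10000, scale = 1.0, pos_id_index = 1, head_dim = 256, max_seq_len = 4096)
```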

exllamav2/model.py

Lines changed: 6 additions & 3 deletions
@@ -109,9 +109,12 @@ def __init__(
             rope_index = 0
 
             if cfg.arch.lm.alternating_swa:
-                swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
-                if cfg.rotary_embedding_base_alt:
-                    rope_index = 1
+                if cfg.sliding_window_pattern > 1:
+                    swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
+                    if cfg.rotary_embedding_base_alt:
+                        rope_index = 1
+                else:
+                    swa = cfg.sliding_window if not bool(layer_idx % 2) else 0
             elif cfg.arch.lm.swa:
                 swa = cfg.sliding_window
             else:
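The new `else` branch covers models that report alternating sliding-window attention without an explicit pattern. A small illustration of the resulting layer layout follows; the helper is hypothetical and only mirrors the `swa` selection, not the rope_index handling, and the sliding-window size is an example value.

```python
# Hypothetical helper mirroring the per-layer sliding-window choice above.
def swa_for_layer(layer_idx: int, sliding_window: int, pattern: int) -> int:
    if pattern > 1:
        # every pattern-th layer is global (swa = 0), the rest use the sliding window
        return sliding_window if (layer_idx + 1) % pattern != 0 else 0
    # pattern <= 1: simple alternation, even layers sliding, odd layers global
    return sliding_window if not bool(layer_idx % 2) else 0

# e.g. pattern = 6 (Gemma3's default_sliding_window_pattern):
print([swa_for_layer(i, 1024, 6) for i in range(8)])  # [1024, 1024, 1024, 1024, 1024, 0, 1024, 1024]
# pattern = 1: every other layer slides:
print([swa_for_layer(i, 1024, 1) for i in range(4)])  # [1024, 0, 1024, 0]
```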
