Commit 8ac3a41

hl475, hmellor, and DarkLight1337 authored
[CI Failure] Fix Gemma3 RoPE configuration for sliding attention layers (#29111)
Signed-off-by: Huamin Li <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 7d6da48 commit 8ac3a41

File tree

1 file changed (+4, -2 lines)


vllm/model_executor/models/gemma3.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -166,10 +166,12 @@ def __init__(
         else:
             # Transformers v4 rope config.
             # Global attention. Use the values in config.json.
-            rope_parameters = config.rope_parameters.copy()
+            rope_parameters = config.rope_parameters
             # Local attention. Override the values in config.json.
             if self.is_sliding:
-                rope_parameters["rope_theta"] = config.rope_local_base_freq
+                rope_parameters = dict(
+                    rope_type="default", rope_theta=config.rope_local_base_freq
+                )

         self.rotary_emb = get_rope(
             self.head_dim,
```
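For context on the change: the pre-fix code copied the global `rope_parameters` dict and overrode only `rope_theta` for sliding layers, which left the global `rope_type` (and any scaling fields) in effect for local attention. After the fix, sliding layers get a fresh dict with `rope_type="default"`, so global rope scaling no longer leaks into the sliding-window layers, and the `.copy()` becomes unnecessary because the shared config dict is never mutated. Below is a minimal, self-contained sketch of the post-fix selection logic; the `SimpleNamespace` config and its field values are illustrative stand-ins for the real Gemma3 text config, not taken from the source:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the HF Gemma3 text config; only the two
# attributes the diff reads are modeled. Field values are illustrative.
config = SimpleNamespace(
    rope_parameters={"rope_type": "linear", "rope_theta": 1_000_000.0, "factor": 8.0},
    rope_local_base_freq=10_000.0,
)


def select_rope_parameters(config, is_sliding: bool) -> dict:
    # Global attention: use the values from config.json unchanged.
    rope_parameters = config.rope_parameters
    # Local (sliding-window) attention: ignore any global rope scaling and
    # fall back to plain RoPE at the local base frequency.
    if is_sliding:
        rope_parameters = dict(
            rope_type="default", rope_theta=config.rope_local_base_freq
        )
    return rope_parameters


print(select_rope_parameters(config, is_sliding=False))
# {'rope_type': 'linear', 'rope_theta': 1000000.0, 'factor': 8.0}
print(select_rope_parameters(config, is_sliding=True))
# {'rope_type': 'default', 'rope_theta': 10000.0}
```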

0 commit comments