
Commit dcbf428

[Frontend] Customizable RoPE theta (#5197)
1 parent 00e6a2d commit dcbf428

5 files changed (+27, -8 lines)

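This change makes the RoPE base frequency (theta) overridable in the same way rope_scaling already was: the value flows from the CLI / EngineArgs through ModelConfig into get_config(), which patches the Hugging Face model config before the model is loaded. A rough usage sketch from the Python API, assuming the usual keyword-argument forwarding from LLM to EngineArgs (model name and theta value are illustrative, not taken from this commit):

from vllm import LLM

# Sketch only: LLM forwards engine keyword arguments to EngineArgs, so after
# this commit rope_theta can be overridden per run alongside rope_scaling.
llm = LLM(
    model="lmsys/longchat-13b-16k",                   # illustrative checkpoint
    rope_scaling={"type": "dynamic", "factor": 2.0},
    rope_theta=1_000_000.0,                           # overrides hf_config.rope_theta
)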

tests/test_config.py

Lines changed: 6 additions & 1 deletion
@@ -63,8 +63,9 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
 
 
-def test_rope_scaling():
+def test_rope_customization():
     TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    TEST_ROPE_THETA = 16_000_000.0
     LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
 
     llama_model_config = ModelConfig(
@@ -76,6 +77,7 @@ def test_rope_scaling():
         seed=0,
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
     assert llama_model_config.max_model_len == 8192
 
     llama_model_config = ModelConfig(
@@ -86,9 +88,12 @@ def test_rope_scaling():
         dtype="float16",
         seed=0,
         rope_scaling=TEST_ROPE_SCALING,
+        rope_theta=TEST_ROPE_THETA,
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
+    assert getattr(llama_model_config.hf_config, "rope_theta",
+                   None) == TEST_ROPE_THETA
     assert llama_model_config.max_model_len == 16384
 
     longchat_model_config = ModelConfig(

vllm/config.py

Lines changed: 3 additions & 1 deletion
@@ -93,6 +93,7 @@ def __init__(
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
         rope_scaling: Optional[dict] = None,
+        rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -113,6 +114,7 @@ def __init__(
         self.revision = revision
         self.code_revision = code_revision
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
         # The tokenizer version is consistent with the model version by default.
         if tokenizer_revision is None:
             self.tokenizer_revision = revision
@@ -132,7 +134,7 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling)
+                                    code_revision, rope_scaling, rope_theta)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
@@ -53,6 +53,7 @@ class EngineArgs:
     revision: Optional[str] = None
     code_revision: Optional[str] = None
     rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -400,6 +401,12 @@ def add_cli_args(
                             type=json.loads,
                             help='RoPE scaling configuration in JSON format. '
                             'For example, {"type":"dynamic","factor":2.0}')
+        parser.add_argument('--rope-theta',
+                            default=None,
+                            type=float,
+                            help='RoPE theta. Use with `rope_scaling`. In '
+                            'some cases, changing the RoPE theta improves the '
+                            'performance of the scaled model.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -630,6 +637,7 @@ def create_engine_config(self, ) -> EngineConfig:
             revision=self.revision,
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
+            rope_theta=self.rope_theta,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,
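The same override is reachable from the command line through the parser built by add_cli_args. A minimal sketch, assuming EngineArgs.add_cli_args and EngineArgs.from_cli_args behave as in this file (model name and values are illustrative):

import argparse

from vllm.engine.arg_utils import EngineArgs

# Build the CLI parser the entrypoints use and parse the new flag.
parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "meta-llama/Llama-2-7b-hf",            # illustrative
    "--rope-scaling", '{"type":"dynamic","factor":2.0}',
    "--rope-theta", "1000000",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.rope_theta)  # 1000000.0, later passed into ModelConfig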

vllm/engine/llm_engine.py

Lines changed: 2 additions & 1 deletion
@@ -162,7 +162,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "rope_scaling=%r, tokenizer_revision=%s, "
+            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "disable_custom_all_reduce=%s, quantization=%s, "
@@ -177,6 +177,7 @@ def __init__(
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.rope_scaling,
+            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,

vllm/transformers_utils/config.py

Lines changed: 8 additions & 5 deletions
@@ -23,7 +23,8 @@ def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
                code_revision: Optional[str] = None,
-               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
+               rope_scaling: Optional[dict] = None,
+               rope_theta: Optional[float] = None) -> PretrainedConfig:
     try:
         if VLLM_USE_MODELSCOPE:
             from modelscope import AutoConfig
@@ -50,10 +51,12 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
-    if rope_scaling is not None:
-        logger.info("Updating rope_scaling from %r to %r",
-                    getattr(config, "rope_scaling", None), rope_scaling)
-        config.update({"rope_scaling": rope_scaling})
+    for key, value in [("rope_scaling", rope_scaling),
+                       ("rope_theta", rope_theta)]:
+        if value is not None:
+            logger.info("Updating %s from %r to %r", key,
+                        getattr(config, key, None), value)
+            config.update({key: value})
     return config
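The override itself relies on PretrainedConfig.update() from transformers, which copies the dictionary entries onto the config object as attributes. A standalone sketch of the same mechanism (the checkpoint name is an assumption; any Llama 3 style config that ships rope_theta behaves the same way):

from transformers import AutoConfig

# Assumed example checkpoint; Llama 3 configs ship rope_theta = 500000.0,
# matching the default asserted in tests/test_config.py above.
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(config.rope_theta)                     # 500000.0

config.update({"rope_theta": 16_000_000.0})  # the same call get_config() makes
print(config.rope_theta)                     # 16000000.0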
