
Commit 36ea790

[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)
1 parent e808156 commit 36ea790

File tree

4 files changed: +59 -5 lines changed

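For context: PEFT allows an adapter's `target_modules` to be either an explicit list of module names or a single regex string matched against full module paths, and this commit teaches vLLM's LoRA loader to accept the regex form. A minimal sketch of how such an adapter might be produced on the PEFT side (the rank and alpha values are illustrative, not part of this commit):

    from peft import LoraConfig

    # With a regex string, PEFT applies LoRA to every module whose full
    # name matches, e.g. model.layers.0.self_attn.W_pack, and stores the
    # pattern as a plain string in adapter_config.json.
    config = LoraConfig(
        r=8,                # illustrative rank
        lora_alpha=16,      # illustrative scaling factor
        target_modules=r"model\..*(W_pack|o_proj)",
    )

Because `target_modules` is then a string rather than a list, the loader changes below have to both normalize it and relax the strict per-module validation.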

tests/lora/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


+@pytest.fixture(scope="session")
+def baichuan_regex_lora_files():
+    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
+
+
 @pytest.fixture(scope="session")
 def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")

tests/lora/test_lora_checkpoints.py

Lines changed: 15 additions & 2 deletions
@@ -5,14 +5,17 @@
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM

-lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+lora_lst = [
+    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
+]


 @pytest.mark.parametrize("lora_name", lora_lst)
 def test_load_checkpoints(
     lora_name,
     baichuan_lora_files,
     baichuan_zero_lora_files,
+    baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
@@ -36,7 +39,7 @@ def test_load_checkpoints(
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
     elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain a prefix,
         # such as "model.layers.0.self_attn.W_pack", and
         # the test should pass.
         LoRAModel.from_local_checkpoint(
@@ -46,6 +49,16 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero-regex":
+        # Test that `target_modules` can be given as a regular expression,
+        # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_regex_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
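To exercise just the new case locally, filtering the parametrized test should work, e.g. `pytest tests/lora/test_lora_checkpoints.py -k regex` (the `-k` filter is a suggestion, not part of the commit).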

vllm/lora/models.py

Lines changed: 5 additions & 2 deletions
@@ -23,6 +23,7 @@
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica import PunicaWrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+                             is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -233,6 +234,8 @@ def from_local_checkpoint(
             # modules.
             unexpected_modules = []
             target_modules = config["target_modules"]
+            if not isinstance(target_modules, list):
+                target_modules = [target_modules]
             for module in target_modules:
                 # Compatible with more modules,
                 # such as: layers.11.self_attn.k_proj
@@ -243,8 +246,8 @@ def from_local_checkpoint(
             # expected_lora_modules. It is not reliable. See
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
-            if unexpected_modules:
-                print(unexpected_modules, "modules")
+            if unexpected_modules and not is_regex_target_modules(
+                    config["target_modules"], expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
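The new `isinstance` guard is load-bearing: when PEFT writes a regex, `target_modules` in adapter_config.json is a bare string, and iterating a string yields single characters instead of module names. A standalone sketch of the failure mode the normalization avoids (illustrative, not code from the commit):

    # Iterating a bare string walks it character by character, so every
    # "module" would be one letter and the suffix check would always fail.
    regex_target = r"model\..*(W_pack|o_proj)"
    print([m.split(".")[-1] for m in regex_target][:5])  # ['m', 'o', 'd', 'e', 'l']

    # Wrapped in a list, the loop sees the whole regex as one entry:
    print([m.split(".")[-1] for m in [regex_target]])    # ['*(W_pack|o_proj)']

The regex entry still fails the per-module suffix check, of course, which is why the error path now also consults `is_regex_target_modules` before rejecting the adapter.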

vllm/lora/utils.py

Lines changed: 34 additions & 1 deletion
@@ -1,5 +1,6 @@
 import os
-from typing import List, Optional, Set, Tuple, Type
+import re
+from typing import List, Optional, Set, Tuple, Type, Union

 import huggingface_hub
 from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
     raise ValueError(f"{name} is unsupported LoRA weight")


+def is_regex_target_modules(load_modules: Union[str, List[str]],
+                            expected_lora_modules: List[str]) -> bool:
+    """
+    PEFT supports passing `target_modules` in the form of a regular
+    expression, such as `model.*(q_proj|k_proj|v_proj)$`. This function
+    is mainly used to determine whether the suffixes in the regular
+    expression are present in the `expected_lora_modules`.
+    """
+
+    def is_valid_regex(pattern):
+        try:
+            re.compile(pattern)
+            return True
+        except re.error:
+            return False
+
+    def is_subset(sub_list, full_list):
+        return set(sub_list).issubset(set(full_list))
+
+    # Similar to PEFT's processing logic, regex-related operations are
+    # only executed when `load_modules` is a `str`.
+    if not isinstance(load_modules, str):
+        return False
+
+    if is_valid_regex(load_modules):
+        match = re.search(r"\((.*?)\)\$?$", load_modules)
+        if match:
+            suffix = match.group(1).split("|")
+            return is_subset(suffix, expected_lora_modules)
+    return False
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
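A few examples of what the helper returns, using a hypothetical `expected` list modeled on the Baichuan modules the tests exercise:

    from vllm.lora.utils import is_regex_target_modules

    expected = ["W_pack", "o_proj", "gate_up_proj"]  # hypothetical

    # All suffixes in the trailing alternation are expected -> True
    assert is_regex_target_modules(r"model\..*(W_pack|o_proj)", expected)

    # An unknown suffix in the alternation -> False
    assert not is_regex_target_modules(r"model\..*(W_pack|foo_proj)", expected)

    # A plain list takes the ordinary (non-regex) path -> False
    assert not is_regex_target_modules(["W_pack", "o_proj"], expected)

Note the helper only inspects a trailing parenthesized alternation, so a valid regex without one (e.g. `model\..*W_pack`) also returns False and falls through to the strict error path.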
