Skip to content

Commit 9caf6fb

Browse files
authored
[Bugfix][LoRA] Fix LoRA bug after supporting Qwen3-Next (#3044)
### What this PR does / why we need it? The LoRA e2e test uses the ilama-3.2-1B model, which uses the transformers.py model files. Its self-attention layer names end with "*.attn", not "*.self_attn". Some other models' attention layer names also end with "*.attn", such as those in baichuan.py and bert.py. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? pytest -sv tests/e2e/singlecard/test_ilama_lora.py pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py - vLLM version: v0.10.2 - vLLM main: vllm-project/vllm@17b4c66 --------- Signed-off-by: paulyu12 <[email protected]>
1 parent 8406aaf commit 9caf6fb

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

.github/workflows/_e2e_test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ jobs:
9292
pytest -sv tests/e2e/singlecard/test_chunked.py
9393
pytest -sv tests/e2e/singlecard/test_embedding.py
9494
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
95-
#pytest -sv tests/e2e/singlecard/test_ilama_lora.py
95+
pytest -sv tests/e2e/singlecard/test_ilama_lora.py
9696
pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
9797
pytest -sv tests/e2e/singlecard/test_quantization.py
9898
pytest -sv tests/e2e/singlecard/test_sampler.py
@@ -174,7 +174,7 @@ jobs:
174174
# external_launcher test is not stable enough. Fix it later
175175
# pytest -sv tests/e2e/multicard/test_external_launcher.py
176176
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
177-
#pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
177+
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
178178
179179
# To avoid oom, we need to run the test in a single process.
180180
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ

vllm_ascend/lora/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
from vllm.config import LoRAConfig
77
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
88
MergedColumnParallelLinearWithLoRA,
9+
MergedQKVParallelLinearWithLoRA,
10+
QKVParallelLinearWithLoRA,
911
RowParallelLinearWithLoRA,
1012
VocabParallelEmbeddingWithLoRA)
13+
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
1114

1215
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
1316
AscendMergedColumnParallelLinear,
17+
AscendQKVParallelLinear,
1418
AscendRowParallelLinear)
1519
from vllm_ascend.ops.vocab_parallel_embedding import \
1620
AscendVocabParallelEmbedding
@@ -69,9 +73,38 @@ def can_replace_layer(
6973
return type(source_layer) is AscendVocabParallelEmbedding
7074

7175

76+
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
    """LoRA support for Ascend's QKV parallel linear layer (unpacked form)."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Only take over the exact Ascend QKV layer, and only when the
        # module is not packed (a single entry in the packed-modules list).
        is_ascend_qkv = type(source_layer) is AscendQKVParallelLinear
        return is_ascend_qkv and len(packed_modules_list) == 1
85+
86+
87+
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    """LoRA support for Ascend's QKV parallel linear layer (merged form)."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # A merged QKV module packs exactly three sub-modules (q, k, v);
        # anything else is handled by a different LoRA wrapper.
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
101+
72102
def refresh_all_lora_classes():
    """Register every Ascend LoRA layer class with vLLM's LoRA resolver.

    vLLM picks replacement LoRA layers from
    ``vllm.lora.utils._all_lora_classes``; adding the Ascend variants here
    lets their ``can_replace_layer`` checks match the Ascend linear layers.
    """
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
        AscendQKVParallelLinearWithLoRA,
        AscendMergedQKVParallelLinearWithLoRA,
    )
    for lora_cls in ascend_lora_classes:
        vllm.lora.utils._all_lora_classes.add(lora_cls)

0 commit comments

Comments
 (0)