|
6 | 6 | from vllm.config import LoRAConfig
|
7 | 7 | from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
8 | 8 | MergedColumnParallelLinearWithLoRA,
|
| 9 | + MergedQKVParallelLinearWithLoRA, |
| 10 | + QKVParallelLinearWithLoRA, |
9 | 11 | RowParallelLinearWithLoRA,
|
10 | 12 | VocabParallelEmbeddingWithLoRA)
|
| 13 | +from vllm.lora.layers.utils import _not_fully_sharded_can_replace |
11 | 14 |
|
12 | 15 | from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
13 | 16 | AscendMergedColumnParallelLinear,
|
| 17 | + AscendQKVParallelLinear, |
14 | 18 | AscendRowParallelLinear)
|
15 | 19 | from vllm_ascend.ops.vocab_parallel_embedding import \
|
16 | 20 | AscendVocabParallelEmbedding
|
@@ -69,9 +73,38 @@ def can_replace_layer(
|
69 | 73 | return type(source_layer) is AscendVocabParallelEmbedding
|
70 | 74 |
|
71 | 75 |
|
| 76 | +class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA): |
| 77 | + |
| 78 | + @classmethod |
| 79 | + @_not_fully_sharded_can_replace |
| 80 | + def can_replace_layer(cls, source_layer: nn.Module, |
| 81 | + lora_config: LoRAConfig, packed_modules_list: list, |
| 82 | + model_config: Optional[PretrainedConfig]) -> bool: |
| 83 | + return type(source_layer) is AscendQKVParallelLinear and len( |
| 84 | + packed_modules_list) == 1 |
| 85 | + |
| 86 | + |
| 87 | +class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): |
| 88 | + |
| 89 | + @classmethod |
| 90 | + @_not_fully_sharded_can_replace |
| 91 | + def can_replace_layer( |
| 92 | + cls, |
| 93 | + source_layer: nn.Module, |
| 94 | + lora_config: LoRAConfig, |
| 95 | + packed_modules_list: list, |
| 96 | + model_config: Optional[PretrainedConfig], |
| 97 | + ) -> bool: |
| 98 | + return (type(source_layer) is AscendQKVParallelLinear |
| 99 | + and len(packed_modules_list) == 3) |
| 100 | + |
| 101 | + |
72 | 102 | def refresh_all_lora_classes():
|
73 | 103 | vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
74 | 104 | vllm.lora.utils._all_lora_classes.add(
|
75 | 105 | AscendMergedColumnParallelLinearWithLoRA)
|
76 | 106 | vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
77 | 107 | vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
| 108 | + vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) |
| 109 | + vllm.lora.utils._all_lora_classes.add( |
| 110 | + AscendMergedQKVParallelLinearWithLoRA) |
0 commit comments