Commit 99d7cab

[LoRA] ReplicatedLinear support LoRA (#7081)
1 parent fb2c1c8 commit 99d7cab

File tree

3 files changed: +199 -0 lines changed

tests/lora/test_layers.py

Lines changed: 103 additions & 0 deletions
@@ -22,6 +22,7 @@
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
+                              ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
@@ -31,6 +32,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def _pretest():
                            atol=atol)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
+
+    torch.set_default_device(device)
+    punica_wrapper = PunicaWrapper(8192, 256, device)
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras,
+                             max_lora_rank=8,
+                             lora_dtype=torch.float16)
+
+    def create_random_linear_replicated_layer():
+
+        linear = ReplicatedLinear(4096,
+                                  4096,
+                                  bias=False,
+                                  params_dtype=torch.float16)
+        linear.weight.data = torch.rand_like(linear.weight.data)
+        lora_linear = ReplicatedLinearWithLoRA(linear)
+
+        lora_linear.create_lora_weights(max_loras, lora_config)
+
+        return linear, lora_linear
+
+    for i in range(10):
+        set_random_seed(i)
+
+        id_to_index = get_random_id_to_index(num_loras, max_loras)
+        linear, lora_linear = create_random_linear_replicated_layer()
+        lora_linear.set_mapping(punica_wrapper)
+        lora_dict, _ = populate_loras(
+            id_to_index,
+            layer=lora_linear,
+            layer_weights=linear.weight,
+        )
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=list(lora_dict.keys()),
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            512,
+            lora_config.lora_extra_vocab_size,
+        )
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+
+        expected_results: List[torch.Tensor] = []
+        for input_, lora_id in zip(inputs, prompt_mapping):
+            lora = lora_dict[lora_id]
+            result = linear(input_)[0]
+            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            expected_results.append(result)
+        expected_result = torch.cat(expected_results)
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+        # Check that resetting the lora weights succeeds
+
+        for slot_idx in range(max_loras):
+            lora_linear.reset_lora(slot_idx)
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=[0],
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+
+        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
+                                       512, lora_config.lora_extra_vocab_size)
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+        expected_result = linear(torch.cat(inputs))[0]
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
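
The new test's expected result spells out the math the LoRA wrapper must reproduce: the base ReplicatedLinear output plus a low-rank update. Below is a minimal pure-PyTorch sketch of that reference computation for a single adapter; the sizes, rank, and scaling value are illustrative, and the real layer batches this over adapter slots through the Punica kernels rather than computing it per input.

import torch

# Reference math mirrored from the expected-result loop in the test above.
# Sizes, rank, and scaling are illustrative; the test uses float16 on CUDA,
# float32 is used here so the sketch also runs on CPU.
hidden_size, rank, scaling = 4096, 8, 1.0
x = torch.rand(1, hidden_size)
weight = torch.rand(hidden_size, hidden_size)   # ReplicatedLinear weight, (out, in)
lora_a = torch.rand(hidden_size, rank)          # LoRA A: (input_size, rank)
lora_b = torch.rand(rank, hidden_size)          # LoRA B: (rank, output_size)

base_out = x @ weight.t()                       # base layer output (bias omitted)
lora_out = x @ lora_a @ lora_b * scaling        # low-rank LoRA update
expected = base_out + lora_out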

vllm/lora/layers.py

Lines changed: 94 additions & 0 deletions
@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import (
@@ -262,6 +263,99 @@ def can_replace_layer(
         return type(source_layer) is VocabParallelEmbedding
 
 
+class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
+
+    def __init__(self, base_layer: ReplicatedLinear) -> None:
+        super().__init__()
+        self.base_layer = base_layer
+        self.input_size = self.base_layer.input_size
+        self.output_size = self.base_layer.output_size
+        self.device = _get_lora_device(self.base_layer)
+
+    def create_lora_weights(
+        self,
+        max_loras: int,
+        lora_config: LoRAConfig,
+        model_config: Optional[PretrainedConfig] = None,
+    ) -> None:
+        self.lora_config = lora_config
+        lora_a_output_size = lora_config.max_lora_rank
+        self.lora_a_stacked = torch.zeros(
+            max_loras,
+            1,
+            lora_a_output_size,
+            self.input_size,
+            dtype=lora_config.lora_dtype,
+            device=self.device,
+        )
+        self.lora_b_stacked = torch.zeros(
+            max_loras,
+            1,
+            self.output_size,
+            lora_config.max_lora_rank,
+            dtype=lora_config.lora_dtype,
+            device=self.device,
+        )
+
+    def reset_lora(self, index: int):
+        self.lora_a_stacked[index] = 0
+        self.lora_b_stacked[index] = 0
+
+    def set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+    ):
+        self.reset_lora(index)
+
+        self.lora_a_stacked[index,
+                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+                                lora_a.T, non_blocking=True)
+        self.lora_b_stacked[index,
+                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+                                lora_b.T, non_blocking=True)
+
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
+                                     self.lora_b_stacked, 1.0)
+        return output
+
+    def forward(self, input_):
+        """Forward of ReplicatedLinearWithLoRA
+
+        Args:
+            input_: Tensor whose last dimension is `input_size`.
+
+        Returns:
+            - output
+            - bias
+        """
+        bias = (self.base_layer.bias
+                if not self.base_layer.skip_bias_add else None)
+
+        # Matrix multiply.
+        output = self.apply(input_, bias)
+
+        output_bias = (self.base_layer.bias
+                       if self.base_layer.skip_bias_add else None)
+        return output, output_bias
+
+    @classmethod
+    @_not_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: List,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        return type(source_layer) is ReplicatedLinear
+
+
 class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     """
     LoRA on top of ColumnParallelLinear layer.
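
Because ReplicatedLinear keeps the full weight on every tensor-parallel rank, the wrapper allocates a single full-size pair of stacked LoRA buffers, shaped (max_loras, 1, max_lora_rank, input_size) and (max_loras, 1, output_size, max_lora_rank), with no per-rank sharding. A minimal construction sketch, mirroring the helper in the new test; it assumes a CUDA device and the usual vLLM initialization, and the 4096/8 sizes are illustrative:

import torch
from vllm.config import LoRAConfig
from vllm.lora.layers import ReplicatedLinearWithLoRA
from vllm.model_executor.layers.linear import ReplicatedLinear

# Wrap a ReplicatedLinear in its LoRA counterpart, as the new test does.
# Assumes a CUDA device and the usual vLLM setup; sizes are illustrative,
# not requirements of the layer.
torch.set_default_device("cuda:0")
lora_config = LoRAConfig(max_loras=8, max_lora_rank=8,
                         lora_dtype=torch.float16)

base = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
lora_layer = ReplicatedLinearWithLoRA(base)
lora_layer.create_lora_weights(max_loras=8, lora_config=lora_config)  # allocate empty adapter slots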

vllm/lora/utils.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
+                              ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
@@ -38,6 +39,7 @@
     QKVParallelLinearWithLora,
     MergedQKVParallelLinearWithLora,
     RowParallelLinearWithLoRA,
+    ReplicatedLinearWithLoRA,
     LogitsProcessorWithLoRA,
     ColumnParallelLinearWithShardedLoRA,
     QKVParallelLinearWithShardedLora,
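
Adding the class to this registry is what lets vLLM's LoRA machinery consider it when replacing model layers: each candidate class exposes can_replace_layer, and ReplicatedLinearWithLoRA matches exactly ReplicatedLinear. The sketch below only illustrates that dispatch pattern; the real registry and lookup helper live in vllm/lora/utils.py and are not part of this diff, so the names introduced here are illustrative.

from typing import List, Optional, Type

from vllm.lora.layers import (ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA)

# Illustrative registry and lookup, not the exact vLLM implementation.
_candidate_classes: List[Type] = [
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,  # newly registered by this commit
]

def find_lora_class(source_layer, lora_config, packed_modules_list,
                    model_config) -> Optional[Type]:
    # Return the first wrapper whose can_replace_layer accepts this layer.
    for lora_cls in _candidate_classes:
        if lora_cls.can_replace_layer(source_layer, lora_config,
                                      packed_modules_list, model_config):
            return lora_cls
    return None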
