@@ -22,6 +22,7 @@
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
+                              ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
@@ -31,6 +32,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def _pretest():
                               atol=atol)


+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
+
+    torch.set_default_device(device)
+    punica_wrapper = PunicaWrapper(8192, 256, device)
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras,
+                             max_lora_rank=8,
+                             lora_dtype=torch.float16)
+
+    def create_random_linear_replicated_layer():
+
+        linear = ReplicatedLinear(4096,
+                                  4096,
+                                  bias=False,
+                                  params_dtype=torch.float16)
+        linear.weight.data = torch.rand_like(linear.weight.data)
+        lora_linear = ReplicatedLinearWithLoRA(linear)
+
+        lora_linear.create_lora_weights(max_loras, lora_config)
+
+        return linear, lora_linear
+
+    for i in range(10):
+        set_random_seed(i)
+
+        id_to_index = get_random_id_to_index(num_loras, max_loras)
+        linear, lora_linear = create_random_linear_replicated_layer()
+        lora_linear.set_mapping(punica_wrapper)
+        lora_dict, _ = populate_loras(
+            id_to_index,
+            layer=lora_linear,
+            layer_weights=linear.weight,
+        )
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=list(lora_dict.keys()),
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            512,
+            lora_config.lora_extra_vocab_size,
+        )
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+
+        expected_results: List[torch.Tensor] = []
+        for input_, lora_id in zip(inputs, prompt_mapping):
+            lora = lora_dict[lora_id]
+            result = linear(input_)[0]
+            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            expected_results.append(result)
+        expected_result = torch.cat(expected_results)
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+        # Check that resetting the lora weights succeeds
+
+        for slot_idx in range(max_loras):
+            lora_linear.reset_lora(slot_idx)
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=[0],
+            num_inputs=32 * num_loras,
+            input_size=(1, 4096),
+            input_range=(0, 1),
+            input_type=torch.float16,
+        )
+        lora_mapping = LoRAMapping(index_mapping,
+                                   prompt_mapping,
+                                   is_prefill=stage)
+
+        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
+                                       512, lora_config.lora_extra_vocab_size)
+
+        lora_result = lora_linear(torch.cat(inputs))[0]
+        expected_result = linear(torch.cat(inputs))[0]
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        assert torch.allclose(lora_result,
+                              expected_result,
+                              rtol=rtol,
+                              atol=atol)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
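For reference, the expected-result loop in the new test implements the standard LoRA forward pass: the base layer output plus the low-rank update input @ A @ B scaled by the adapter's scaling factor. Below is a minimal, self-contained eager-mode sketch of that math; the shapes mirror the test's 4096-wide layer and rank-8 config, the scaling value is illustrative, and none of vLLM's fused Punica kernels are involved.

import torch

# Reference LoRA math mirrored from the test above: hidden size 4096, rank 8.
# scaling is illustrative here (in the test it comes from the generated
# adapter, lora.scaling); float32 keeps the sketch runnable on CPU.
hidden_size, rank, scaling = 4096, 8, 0.5

x = torch.rand(1, hidden_size)                 # one token's activations
weight = torch.rand(hidden_size, hidden_size)  # base weight, (out_features, in_features)
lora_a = torch.rand(hidden_size, rank)         # low-rank down-projection A
lora_b = torch.rand(rank, hidden_size)         # low-rank up-projection B

base_out = x @ weight.t()                      # what linear(input_)[0] computes
lora_out = base_out + x @ lora_a @ lora_b * scaling
assert lora_out.shape == (1, hidden_size)

Because ReplicatedLinear keeps the full weight on every rank, the LoRA delta needs no sharding, which is why the test can compare the layer's output against this single-device formula directly.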