|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
4 | 4 | import os
|
| 5 | +from typing import Optional |
5 | 6 |
|
6 | 7 | import pytest
|
7 | 8 | import torch
|
|
26 | 27 | from vllm.lora.request import LoRARequest
|
27 | 28 | from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
|
28 | 29 | from vllm.platforms import current_platform
|
| 30 | +from vllm.validation.plugins import ( |
| 31 | + ModelType, |
| 32 | + ModelValidationPlugin, |
| 33 | + ModelValidationPluginRegistry, |
| 34 | +) |
29 | 35 |
|
30 | 36 | from .utils import create_peft_lora
|
31 | 37 |
|
@@ -714,3 +720,36 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
714 | 720 | torch.testing.assert_close(
|
715 | 721 | packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
|
716 | 722 | )
|
| 723 | + |
| 724 | + |
class MyModelValidator(ModelValidationPlugin):
    """Test-only validation plugin that rejects every model.

    Both overrides must keep the exact ``ModelValidationPlugin`` signatures so
    the registry can invoke them polymorphically.
    """

    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Opt in unconditionally so validate_model() runs for any adapter load.
        return True

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # Deliberately BaseException (not Exception): the companion test checks
        # that a validation failure propagates even past broad
        # ``except Exception`` handlers in the load path.
        raise BaseException("Model did not validate")
| 733 | + |
| 734 | + |
def test_worker_adapter_manager_security_policy(dist_init, dummy_model_gate_up):
    """A registered validation plugin that rejects a model must abort LoRA load.

    Registers the always-failing ``MyModelValidator`` and asserts that
    activating any adapter surfaces its error instead of loading the adapter.
    """
    # NOTE(review): the plugin is registered globally and never removed here;
    # if the registry exposes a deregistration/cleanup hook, using it (e.g. in
    # a finally/fixture) would keep the always-failing validator from leaking
    # into later tests — confirm.
    ModelValidationPluginRegistry.register_plugin("test", MyModelValidator())

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16),
        lora_config=LoRAConfig(
            max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
        ),
    )

    manager = WorkerLoRAManager(
        vllm_config, "cpu", EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
    )
    manager.create_lora_manager(dummy_model_gate_up)

    empty_mapping = LoRAMapping([], [])
    # Any adapter activation should trip MyModelValidator.validate_model.
    with pytest.raises(BaseException, match="Model did not validate"):
        manager.set_active_adapters([LoRARequest("1", 1, "/not/used")], empty_mapping)
0 commit comments