diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 9df3a07a9e5e..43ee06afeda7 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -137,6 +137,8 @@ def populate_loras(
     layer: BaseLayerWithLoRA,
     layer_weights: torch.Tensor,
     repeats: int = 1,
+    use_zero_lora_a: bool = False,
+    use_zero_lora_b: bool = False,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora layers with lora weights.

@@ -168,6 +170,8 @@
                 sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=layer_weights,
+                    use_zero_lora_a=use_zero_lora_a,
+                    use_zero_lora_b=use_zero_lora_b,
                 )
                 sublora.lora_b = sublora.lora_b[
                     (sublora_len * i) : (sublora_len * (i + 1)), :
                 ]
@@ -347,6 +351,80 @@ def create_random_embedding_layer():
     torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 2])
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+@pytest.mark.parametrize(
+    "use_zero_lora_a,use_zero_lora_b", [(True, False), (False, True), (True, True)]
+)
+def test_embeddings_with_zero_weights(
+    dist_init, num_loras, device, stage, use_zero_lora_a, use_zero_lora_b
+) -> None:
+    """Ensure that LoRA embeddings with zero matrices produce identical
+    outputs to the base layer."""
+    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
+    # device, see: https://github.com/triton-lang/triton/issues/2925
+    # Same below.
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
+    torch.set_default_device(device)
+    max_loras = 8
+    vocab_size = 512
+
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_config = LoRAConfig(
+        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
+    )
+
+    def create_random_embedding_layer():
+        embedding = VocabParallelEmbedding(vocab_size, 256)
+        embedding.weight.data = torch.rand_like(embedding.weight.data)
+        embedding.weight.data[vocab_size:, :] = 0
+        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
+        lora_embedding.create_lora_weights(max_loras, lora_config)
+        return embedding, lora_embedding
+
+    # Only run a single random seed, since the behavior of the LoRA layers
+    # should be equivalent to just adding zeros to the base layer.
+    id_to_index = get_random_id_to_index(num_loras, max_loras)
+    embedding, lora_embedding = create_random_embedding_layer()
+    lora_embedding.set_mapping(punica_wrapper)
+
+    # Explicitly set lora_a and/or lora_b to zero
+    lora_dict, _ = populate_loras(
+        id_to_index,
+        layer=lora_embedding,
+        layer_weights=embedding.weight.T,
+        use_zero_lora_a=use_zero_lora_a,
+        use_zero_lora_b=use_zero_lora_b,
+    )
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=num_loras * 3,
+        input_size=(200,),
+        input_range=(1, vocab_size),
+        device=device,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    punica_wrapper.update_metadata(
+        lora_mapping,
+        id_to_index,
+        max_loras,
+        vocab_size,
+    )
+
+    cat_inputs = torch.cat(inputs)
+    lora_result = lora_embedding(cat_inputs)
+    base_layer_result = lora_embedding.base_layer.forward(cat_inputs)
+
+    rtol, atol = TOLERANCES[lora_result.dtype]
+    torch.testing.assert_close(lora_result, base_layer_result, rtol=rtol, atol=atol)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4])
 @pytest.mark.parametrize("device", DEVICES)
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 6aba5299b582..d6f1a6b3a876 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -23,21 +23,36 @@ def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
     def get_module_lora(self, module_name: str) -> LoRALayerWeights:
         return self._loras[module_name]

+    @staticmethod
+    def get_lora_mat(size, is_zero, **kwargs):
+        if is_zero:
+            return torch.zeros(size, **kwargs)
+        return torch.rand(size, **kwargs)
+
     def init_random_lora(
         self,
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
+        use_zero_lora_a: bool = False,
+        use_zero_lora_b: bool = False,
     ):
         lora = LoRALayerWeights(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand(
-                [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
+            # lora_a / lora_b are random unless explicitly requested to be zero
+            lora_a=self.get_lora_mat(
+                [rank, weight.shape[1]],
+                is_zero=use_zero_lora_a,
+                dtype=weight.dtype,
+                device=self._device,
             ),
-            lora_b=torch.rand(
-                [weight.shape[0], rank], dtype=weight.dtype, device=self._device
+            lora_b=self.get_lora_mat(
+                [weight.shape[0], rank],
+                is_zero=use_zero_lora_b,
+                dtype=weight.dtype,
+                device=self._device,
             ),
         )
         self.set_module_lora(module_name, lora)
diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index 5b1f7886bc23..2fad95650c08 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -91,6 +91,13 @@ def set_lora(
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Short circuit: if either A or B is all zeros, the LoRA delta is
+        # zero, so we can just run the base layer directly.
+        if bool(torch.all(self.lora_a_stacked == 0)) or bool(
+            torch.all(self.lora_b_stacked == 0)
+        ):
+            return self.base_layer.forward(x)
+
         # NB: Don't use torch.narrow here. torch.narrow triggers some
         # Dynamic Shape specialization in torch.compile
         num_tokens = x.shape[0]