From 8ffd97f78038657522524c6fc34b3502c4f5ec64 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 20 Nov 2025 00:09:13 +0000
Subject: [PATCH 1/7] Support leaving lora a / b as zero in random init util

Signed-off-by: Alex-Brooks
---
 tests/lora/utils.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 6aba5299b582..d6f1a6b3a876 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -23,21 +23,36 @@ def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
     def get_module_lora(self, module_name: str) -> LoRALayerWeights:
         return self._loras[module_name]
 
+    @staticmethod
+    def get_lora_mat(size, is_zero, **kwargs):
+        if is_zero:
+            return torch.zeros(size, **kwargs)
+        return torch.rand(size, **kwargs)
+
     def init_random_lora(
         self,
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
+        use_zero_lora_a: bool = False,
+        use_zero_lora_b: bool = False,
     ):
         lora = LoRALayerWeights(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand(
-                [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
+            # lora_a / lora_b are random unless we explicitly say to use 0 mats
+            lora_a=self.get_lora_mat(
+                [rank, weight.shape[1]],
+                is_zero=use_zero_lora_a,
+                dtype=weight.dtype,
+                device=self._device,
             ),
-            lora_b=torch.rand(
-                [weight.shape[0], rank], dtype=weight.dtype, device=self._device
+            lora_b=self.get_lora_mat(
+                [weight.shape[0], rank],
+                is_zero=use_zero_lora_b,
+                dtype=weight.dtype,
+                device=self._device,
             ),
         )
         self.set_module_lora(module_name, lora)
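
With PATCH 1 applied, the new flags can be exercised roughly as below. This is a sketch only: it assumes tests/lora/utils.py is importable (e.g. when running from the repository root), that a CPU device is acceptable for the helper, and the module name "fake_0" and the weight shape are arbitrary.

# Sketch of the zero-init flags added in PATCH 1; assumes the test utilities
# are importable and that LoRALayerWeights exposes lora_a / lora_b as tensors.
import torch

from tests.lora.utils import DummyLoRAManager

manager = DummyLoRAManager(torch.device("cpu"))
weight = torch.rand(32, 16, dtype=torch.float16)

# Force A to all zeros while B stays random; the LoRA delta B @ A is then zero.
manager.init_random_lora("fake_0", weight=weight, rank=8, use_zero_lora_a=True)
lora = manager.get_module_lora("fake_0")

assert torch.all(lora.lora_a == 0)
assert not torch.all(lora.lora_b == 0)
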
From 006fe229ccd7f943fdd44944b371bee976754265 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 20 Nov 2025 00:09:57 +0000
Subject: [PATCH 2/7] Add test explicitly comparing outputs to base layer for zero lora

Signed-off-by: Alex-Brooks
---
 tests/lora/test_layers.py | 229 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 229 insertions(+)

diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 9df3a07a9e5e..7f1f5d4bed6d 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -137,6 +137,8 @@ def populate_loras(
     layer: BaseLayerWithLoRA,
     layer_weights: torch.Tensor,
     repeats: int = 1,
+    use_zero_lora_a: bool = False,
+    use_zero_lora_b: bool = False,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora layers with lora weights.
@@ -168,6 +170,12 @@ def populate_loras(
         sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
             module_name=f"fake_{i}",
             weight=layer_weights,
+<<<<<<< HEAD
+=======
+            generate_embeddings_tensor=generate_embeddings_tensor,
+            use_zero_lora_a=use_zero_lora_a,
+            use_zero_lora_b=use_zero_lora_b,
+>>>>>>> 0f3de2ff0 (Add test explicitly comparing outputs to base layer for zero lora)
         )
         sublora.lora_b = sublora.lora_b[
             (sublora_len * i) : (sublora_len * (i + 1)), :
         ]
@@ -348,6 +356,227 @@ def create_random_embedding_layer():
 
 
 @torch.inference_mode()
+<<<<<<< HEAD
+=======
+@pytest.mark.parametrize("num_loras", [1, 2])
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+@pytest.mark.parametrize(
+    "use_zero_lora_a,use_zero_lora_b", [(True, False), (False, True), (True, True)]
+)
+def test_embeddings_with_zero_weights(
+    dist_init, num_loras, device, stage, use_zero_lora_a, use_zero_lora_b
+) -> None:
+    """Ensure that LoRA embeddings with zero matrices produce identical
+    outputs to the base layer."""
+    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
+    # device, see: https://github.com/triton-lang/triton/issues/2925
+    # Same below.
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
+    torch.set_default_device(device)
+    max_loras = 8
+    vocab_size = 512
+
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_config = LoRAConfig(
+        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
+    )
+
+    def create_random_embedding_layer():
+        embedding = VocabParallelEmbedding(vocab_size, 256)
+        embedding.weight.data = torch.rand_like(embedding.weight.data)
+        embedding.weight.data[vocab_size:, :] = 0
+        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
+        lora_embedding.create_lora_weights(max_loras, lora_config)
+        return embedding, lora_embedding
+
+    for i in range(NUM_RANDOM_SEEDS):
+        set_random_seed(i)
+
+        id_to_index = get_random_id_to_index(num_loras, max_loras)
+        embedding, lora_embedding = create_random_embedding_layer()
+        lora_embedding.set_mapping(punica_wrapper)
+        # Don't populate - stacked tensors are originally 0 anyway
+        lora_dict, _ = populate_loras(
+            id_to_index,
+            layer=lora_embedding,
+            layer_weights=embedding.weight.T,
+            use_zero_lora_b=use_zero_lora_a,
+            use_zero_lora_a=use_zero_lora_b,
+        )
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=list(lora_dict.keys()),
+            num_inputs=num_loras * 3,
+            input_size=(200,),
+            input_range=(1, vocab_size),
+            device=device,
+        )
+        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            vocab_size,
+            lora_config.lora_extra_vocab_size,
+        )
+
+        cat_inputs = torch.cat(inputs)
+        lora_result = lora_embedding(cat_inputs)
+        base_layer_result = lora_embedding.base_layer.forward(cat_inputs)
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        torch.testing.assert_close(lora_result, base_layer_result, rtol=rtol, atol=atol)
+
+
+@torch.inference_mode()
+# @pytest.mark.skip(
+#     reason="Fails when loras are in any slot other than the first.")
+@pytest.mark.parametrize("num_loras", [1, 2, 4])
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+@pytest.mark.parametrize("stage", STAGES)
+def test_embeddings_with_new_embeddings(
+    dist_init, num_loras, device, vocab_size, stage
+) -> None:
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
+    torch.set_default_device(device)
+    max_loras = 8
+    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_config = LoRAConfig(
+        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
+    )
+
+    def create_random_embedding_layer():
+        embedding = VocabParallelEmbedding(vocab_size, 256)
+        embedding_data = torch.rand_like(embedding.weight.data)
+        embedding.weight.data = embedding_data
+        embedding.weight.data[vocab_size:, :] = 0
+        expanded_embedding = VocabParallelEmbedding(
+            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
+            256,
+            org_num_embeddings=vocab_size,
+        )
+        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
+        # We need to deepcopy the embedding as it will be modified
+        # in place
+        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
+        lora_embedding.create_lora_weights(max_loras, lora_config)
+
+        return expanded_embedding, lora_embedding
+
+    for i in range(NUM_RANDOM_SEEDS):
+        set_random_seed(i)
+
+        id_to_index = get_random_id_to_index(num_loras, max_loras)
+        expanded_embedding, lora_embedding = create_random_embedding_layer()
+        lora_dict, _ = populate_loras(
+            id_to_index,
+            layer=lora_embedding,
+            layer_weights=torch.zeros(
+                (256, vocab_size + lora_config.lora_extra_vocab_size)
+            ),
+            generate_embeddings_tensor=256,
+        )
+
+        lora_embedding.set_mapping(punica_wrapper)
+        # All embeddings tensors have the same shape.
+        embeddings_tensors = [
+            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
+        ]
+        embeddings_tensor_len = embeddings_tensors[0].shape[0]
+
+        # Add empty embeddings_tensors for unoccupied lora slots.
+        for _ in range(max_loras - len(embeddings_tensors)):
+            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=list(lora_dict.keys()),
+            num_inputs=num_loras * 3,
+            input_size=(200,),
+            input_range=(1, vocab_size),
+            device=device,
+        )
+        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            vocab_size,
+            lora_config.lora_extra_vocab_size,
+        )
+        original_inputs = deepcopy(inputs)
+
+        # Force some of the inputs to be in the extended embeddings range
+        # to guarantee that their behavior is tested.
+        for input_, original_input_, lora_id in zip(
+            inputs, original_inputs, prompt_mapping
+        ):
+            embedding_id = lora_id - 1
+            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+            original_input_[-1] = vocab_size
+            input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
+            original_input_[-2] = vocab_size + embeddings_tensor_len - 1
+
+        expanded_embedding.weight[
+            vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
+        ] = torch.cat(embeddings_tensors)
+
+        lora_result = lora_embedding(torch.cat(original_inputs))
+
+        expected_results: list[torch.Tensor] = []
+        for input_, original_input_, lora_id in zip(
+            inputs, original_inputs, prompt_mapping
+        ):
+            lora = lora_dict[lora_id]
+            result = expanded_embedding(input_)
+            after_a = F.embedding(
+                original_input_,
+                lora.lora_a.T,
+            )
+            result += after_a @ lora.lora_b.T
+            expected_results.append(result)
+        expected_result = torch.cat(expected_results)
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
+
+        # Check that resetting the lora weights succeeds
+
+        for slot_idx in range(max_loras):
+            lora_embedding.reset_lora(slot_idx)
+
+        inputs, index_mapping, prompt_mapping = create_random_inputs(
+            active_lora_ids=[0],
+            num_inputs=num_loras * 3,
+            input_size=(200,),
+            input_range=(1, vocab_size),
+            device=device,
+        )
+        original_inputs = deepcopy(inputs)
+        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            id_to_index,
+            max_loras,
+            vocab_size,
+            lora_config.lora_extra_vocab_size,
+        )
+        lora_result = lora_embedding(torch.cat(original_inputs))
+        expected_result = expanded_embedding(torch.cat(inputs))
+
+        rtol, atol = TOLERANCES[lora_result.dtype]
+        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
+
+
+@torch.inference_mode()
+>>>>>>> 0f3de2ff0 (Add test explicitly comparing outputs to base layer for zero lora)
 @pytest.mark.parametrize("num_loras", [1, 2, 4])
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])

From c6ab03f26362d71b6b1bb56f946a393416247a9e Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 20 Nov 2025 00:16:47 +0000
Subject: [PATCH 3/7] Only test one random seed

Signed-off-by: Alex-Brooks
---
 tests/lora/test_layers.py | 68 +++++++++++++++++++--------------------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 7f1f5d4bed6d..4a158f46cdb2 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -393,43 +393,43 @@ def create_random_embedding_layer():
         lora_embedding.create_lora_weights(max_loras, lora_config)
         return embedding, lora_embedding
 
-    for i in range(NUM_RANDOM_SEEDS):
-        set_random_seed(i)
-
-        id_to_index = get_random_id_to_index(num_loras, max_loras)
-        embedding, lora_embedding = create_random_embedding_layer()
-        lora_embedding.set_mapping(punica_wrapper)
-        # Don't populate - stacked tensors are originally 0 anyway
-        lora_dict, _ = populate_loras(
-            id_to_index,
-            layer=lora_embedding,
-            layer_weights=embedding.weight.T,
-            use_zero_lora_b=use_zero_lora_a,
-            use_zero_lora_a=use_zero_lora_b,
-        )
+    # Only run a single random seed, as the behavior of the LoRA layers
+    # should be equivalent to just adding zeros to the base layer.
+    id_to_index = get_random_id_to_index(num_loras, max_loras)
+    embedding, lora_embedding = create_random_embedding_layer()
+    lora_embedding.set_mapping(punica_wrapper)
+
+    # Explicitly set lora_a and/or lora_b to zero
+    lora_dict, _ = populate_loras(
+        id_to_index,
+        layer=lora_embedding,
+        layer_weights=embedding.weight.T,
+        use_zero_lora_b=use_zero_lora_a,
+        use_zero_lora_a=use_zero_lora_b,
+    )
 
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=list(lora_dict.keys()),
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=num_loras * 3,
+        input_size=(200,),
+        input_range=(1, vocab_size),
+        device=device,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    punica_wrapper.update_metadata(
+        lora_mapping,
+        id_to_index,
+        max_loras,
+        vocab_size,
+        lora_config.lora_extra_vocab_size,
+    )
 
-        cat_inputs = torch.cat(inputs)
-        lora_result = lora_embedding(cat_inputs)
-        base_layer_result = lora_embedding.base_layer.forward(cat_inputs)
+    cat_inputs = torch.cat(inputs)
+    lora_result = lora_embedding(cat_inputs)
+    base_layer_result = lora_embedding.base_layer.forward(cat_inputs)
 
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, base_layer_result, rtol=rtol, atol=atol)
+    rtol, atol = TOLERANCES[lora_result.dtype]
+    torch.testing.assert_close(lora_result, base_layer_result, rtol=rtol, atol=atol)
 
 
 @torch.inference_mode()

From 263cce8472880a6d7cf190b80275c2cb859dda8d Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 20 Nov 2025 11:27:55 +0000
Subject: [PATCH 4/7] only short circuit to base layer if no tokens are added

Signed-off-by: Alex-Brooks
---
 vllm/lora/layers/vocal_parallel_embedding.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index 5b1f7886bc23..113aaf71aa19 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -91,6 +91,12 @@ def set_lora(
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
+        # Short circuit: if either A or B is all zeros, the LoRA delta is
+        # zero, so we can just call the base layer directly.
+        if bool(torch.all(self.lora_a_stacked == 0)) or bool(torch.all(self.lora_b_stacked == 0)):
+            return self.base_layer.forward(x)
+
         # NB: Don't use torch.narrow here. torch.narrow triggers some
         # Dynamic Shape specialization in torch.compile
         num_tokens = x.shape[0]

From ef1c56c5e15225f765834c199d243f6f8f0f9320 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 21 Nov 2025 07:57:40 +0000
Subject: [PATCH 5/7] remove vocab from zero lora test

Signed-off-by: Alex-Brooks
---
 tests/lora/test_layers.py | 151 --------------------------------------
 1 file changed, 151 deletions(-)

diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 4a158f46cdb2..2483729ee0e2 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -170,12 +170,8 @@ def populate_loras(
         sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
             module_name=f"fake_{i}",
             weight=layer_weights,
-<<<<<<< HEAD
-=======
-            generate_embeddings_tensor=generate_embeddings_tensor,
             use_zero_lora_a=use_zero_lora_a,
             use_zero_lora_b=use_zero_lora_b,
->>>>>>> 0f3de2ff0 (Add test explicitly comparing outputs to base layer for zero lora)
         )
         sublora.lora_b = sublora.lora_b[
             (sublora_len * i) : (sublora_len * (i + 1)), :
         ]
@@ -356,8 +352,6 @@ def create_random_embedding_layer():
 
 
 @torch.inference_mode()
-<<<<<<< HEAD
-=======
 @pytest.mark.parametrize("num_loras", [1, 2])
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
 @pytest.mark.parametrize(
     "use_zero_lora_a,use_zero_lora_b", [(True, False), (False, True), (True, True)]
 )
@@ -421,7 +415,6 @@ def create_random_embedding_layer():
         id_to_index,
         max_loras,
         vocab_size,
-        lora_config.lora_extra_vocab_size,
     )
 
     cat_inputs = torch.cat(inputs)
@@ -433,150 +426,6 @@ def create_random_embedding_layer():
 
 
 @torch.inference_mode()
-# @pytest.mark.skip(
-#     reason="Fails when loras are in any slot other than the first.")
-@pytest.mark.parametrize("num_loras", [1, 2, 4])
-@pytest.mark.parametrize("device", DEVICES)
-@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
-@pytest.mark.parametrize("stage", STAGES)
-def test_embeddings_with_new_embeddings(
-    dist_init, num_loras, device, vocab_size, stage
-) -> None:
-    if current_platform.is_cuda_alike():
-        torch.cuda.set_device(device)
-
-    torch.set_default_device(device)
-    max_loras = 8
-    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
-    assert check_punica_wrapper(punica_wrapper)
-    lora_config = LoRAConfig(
-        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
-    )
-
-    def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(vocab_size, 256)
-        embedding_data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data = embedding_data
-        embedding.weight.data[vocab_size:, :] = 0
-        expanded_embedding = VocabParallelEmbedding(
-            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
-            256,
-            org_num_embeddings=vocab_size,
-        )
-        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
-        # We need to deepcopy the embedding as it will be modified
-        # in place
-        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
-        lora_embedding.create_lora_weights(max_loras, lora_config)
-
-        return expanded_embedding, lora_embedding
-
-    for i in range(NUM_RANDOM_SEEDS):
-        set_random_seed(i)
-
-        id_to_index = get_random_id_to_index(num_loras, max_loras)
-        expanded_embedding, lora_embedding = create_random_embedding_layer()
-        lora_dict, _ = populate_loras(
-            id_to_index,
-            layer=lora_embedding,
-            layer_weights=torch.zeros(
-                (256, vocab_size + lora_config.lora_extra_vocab_size)
-            ),
-            generate_embeddings_tensor=256,
-        )
-
-        lora_embedding.set_mapping(punica_wrapper)
-        # All embeddings tensors have the same shape.
-        embeddings_tensors = [
-            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
-        ]
-        embeddings_tensor_len = embeddings_tensors[0].shape[0]
-
-        # Add empty embeddings_tensors for unoccupied lora slots.
-        for _ in range(max_loras - len(embeddings_tensors)):
-            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=list(lora_dict.keys()),
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        original_inputs = deepcopy(inputs)
-
-        # Force some of the inputs to be in the extended embeddings range
-        # to guarantee that their behavior is tested.
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            embedding_id = lora_id - 1
-            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
-            original_input_[-1] = vocab_size
-            input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
-            original_input_[-2] = vocab_size + embeddings_tensor_len - 1
-
-        expanded_embedding.weight[
-            vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
-        ] = torch.cat(embeddings_tensors)
-
-        lora_result = lora_embedding(torch.cat(original_inputs))
-
-        expected_results: list[torch.Tensor] = []
-        for input_, original_input_, lora_id in zip(
-            inputs, original_inputs, prompt_mapping
-        ):
-            lora = lora_dict[lora_id]
-            result = expanded_embedding(input_)
-            after_a = F.embedding(
-                original_input_,
-                lora.lora_a.T,
-            )
-            result += after_a @ lora.lora_b.T
-            expected_results.append(result)
-        expected_result = torch.cat(expected_results)
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-        # Check that resetting the lora weights succeeds
-
-        for slot_idx in range(max_loras):
-            lora_embedding.reset_lora(slot_idx)
-
-        inputs, index_mapping, prompt_mapping = create_random_inputs(
-            active_lora_ids=[0],
-            num_inputs=num_loras * 3,
-            input_size=(200,),
-            input_range=(1, vocab_size),
-            device=device,
-        )
-        original_inputs = deepcopy(inputs)
-        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
-        punica_wrapper.update_metadata(
-            lora_mapping,
-            id_to_index,
-            max_loras,
-            vocab_size,
-            lora_config.lora_extra_vocab_size,
-        )
-        lora_result = lora_embedding(torch.cat(original_inputs))
-        expected_result = expanded_embedding(torch.cat(inputs))
-
-        rtol, atol = TOLERANCES[lora_result.dtype]
-        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
-
-
-@torch.inference_mode()
->>>>>>> 0f3de2ff0 (Add test explicitly comparing outputs to base layer for zero lora)
 @pytest.mark.parametrize("num_loras", [1, 2, 4])
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])

From dddca287e754d0cf492f1b5cd8d31b4144acd5df Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 21 Nov 2025 07:58:43 +0000
Subject: [PATCH 6/7] remove unused extra vocab mask

Signed-off-by: Alex-Brooks
---
 vllm/lora/layers/vocal_parallel_embedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py
index 113aaf71aa19..2fad95650c08 100644
--- a/vllm/lora/layers/vocal_parallel_embedding.py
+++ b/vllm/lora/layers/vocal_parallel_embedding.py
@@ -91,10 +91,11 @@ def set_lora(
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
         # Short circuit: if either A or B is all zeros, the LoRA delta is
         # zero, so we can just call the base layer directly.
-        if bool(torch.all(self.lora_a_stacked == 0)) or bool(torch.all(self.lora_b_stacked == 0)):
+        if bool(torch.all(self.lora_a_stacked == 0)) or bool(
+            torch.all(self.lora_b_stacked == 0)
+        ):
             return self.base_layer.forward(x)
 
         # NB: Don't use torch.narrow here. torch.narrow triggers some

From ecd19f4a8e06df37cf3c7a5cf50370ba9b924820 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 21 Nov 2025 08:44:35 +0000
Subject: [PATCH 7/7] fix variable swap

Signed-off-by: Alex-Brooks
---
 tests/lora/test_layers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 2483729ee0e2..43ee06afeda7 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -398,8 +398,8 @@ def create_random_embedding_layer():
         id_to_index,
         layer=lora_embedding,
         layer_weights=embedding.weight.T,
-        use_zero_lora_b=use_zero_lora_a,
-        use_zero_lora_a=use_zero_lora_b,
+        use_zero_lora_a=use_zero_lora_a,
+        use_zero_lora_b=use_zero_lora_b,
    )
 
    inputs, index_mapping, prompt_mapping = create_random_inputs(
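
As a closing illustration (not part of the patches above): a minimal standalone sketch of why the short circuit in PATCH 4/6 is safe. If either the stacked A or B matrices are entirely zero, the LoRA contribution vanishes and the layer's output must equal the base layer's output, which is exactly what the new test asserts. The shapes mirror the test utilities (A is [rank, in_dim], B is [out_dim, rank]); everything else here is illustrative.

# Standalone sketch, plain PyTorch only; names and shapes are illustrative.
import torch

rank, in_dim, out_dim = 8, 16, 32
x = torch.rand(4, in_dim)

lora_a = torch.zeros(rank, in_dim)  # A forced to zero, as with use_zero_lora_a=True
lora_b = torch.rand(out_dim, rank)  # B left random

# The LoRA contribution factors through both A and B. For a linear layer it is
# x @ A.T @ B.T; for the embedding layer the first step is a row lookup into
# A.T instead, but the conclusion is the same: a zero A (or B) gives a zero delta.
delta = x @ lora_a.T @ lora_b.T
assert torch.all(delta == 0)  # so base_output + delta == base_output

This is also why, after PATCH 3, the comparison only needs a single random seed: any random B (or A) paired with an all-zero counterpart yields the same all-zero delta.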