78 changes: 78 additions & 0 deletions tests/lora/test_layers.py
@@ -137,6 +137,8 @@ def populate_loras(
layer: BaseLayerWithLoRA,
layer_weights: torch.Tensor,
repeats: int = 1,
use_zero_lora_a: bool = False,
use_zero_lora_b: bool = False,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
"""This method populates the lora layers with lora weights.

@@ -168,6 +170,8 @@ def populate_loras(
sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
use_zero_lora_a=use_zero_lora_a,
use_zero_lora_b=use_zero_lora_b,
)
sublora.lora_b = sublora.lora_b[
(sublora_len * i) : (sublora_len * (i + 1)), :
@@ -347,6 +351,80 @@ def create_random_embedding_layer():
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize(
"use_zero_lora_a,use_zero_lora_b", [(True, False), (False, True), (True, True)]
)
def test_embeddings_with_zero_weights(
dist_init, num_loras, device, stage, use_zero_lora_a, use_zero_lora_b
) -> None:
"""Ensure that LoRA embeddings with zero matrices produce identical
outputs to the base layer."""
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
vocab_size = 512

punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig(
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
)

def create_random_embedding_layer():
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data)
embedding.weight.data[vocab_size:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config)
return embedding, lora_embedding

# Only run a single random seed, as the behavior of the LoRA layers
# should be equivalent to just adding zeros to the base layer.
id_to_index = get_random_id_to_index(num_loras, max_loras)
embedding, lora_embedding = create_random_embedding_layer()
lora_embedding.set_mapping(punica_wrapper)

# Explicitly set lora_a and/or lora_b to zero
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_embedding,
layer_weights=embedding.weight.T,
use_zero_lora_a=use_zero_lora_a,
use_zero_lora_b=use_zero_lora_b,
)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200,),
input_range=(1, vocab_size),
device=device,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping,
id_to_index,
max_loras,
vocab_size,
)

cat_inputs = torch.cat(inputs)
lora_result = lora_embedding(cat_inputs)
base_layer_result = lora_embedding.base_layer.forward(cat_inputs)

rtol, atol = TOLERANCES[lora_result.dtype]
torch.testing.assert_close(lora_result, base_layer_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
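For context on why the new test can compare against the base layer directly: the LoRA path only adds a low-rank delta on top of the base output, so the delta vanishes whenever either factor is all zeros. Below is a minimal standalone sketch of that identity in plain PyTorch; the shapes are illustrative and do not reflect vLLM's stacked-tensor layout.

import torch

# Sketch of the identity the new test relies on: the LoRA output is
# base(x) + (x @ A^T) @ B^T, so a zero A (or zero B) gives a zero delta
# and the LoRA layer must match the base layer exactly.
torch.manual_seed(0)
x = torch.rand(4, 16)          # batch of inputs
weight = torch.rand(32, 16)    # base layer weight
lora_a = torch.zeros(8, 16)    # rank-8 A matrix, forced to zero
lora_b = torch.rand(32, 8)     # B matrix, random

base_out = x @ weight.T
lora_out = base_out + (x @ lora_a.T) @ lora_b.T
assert torch.equal(lora_out, base_out)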
23 changes: 19 additions & 4 deletions tests/lora/utils.py
@@ -23,21 +23,36 @@ def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
def get_module_lora(self, module_name: str) -> LoRALayerWeights:
return self._loras[module_name]

@staticmethod
def get_lora_mat(size, is_zero, **kwargs):
if is_zero:
return torch.zeros(size, **kwargs)
return torch.rand(size, **kwargs)

def init_random_lora(
self,
module_name: str,
weight: torch.Tensor,
rank: int = 8,
use_zero_lora_a: bool = False,
use_zero_lora_b: bool = False,
):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
-lora_a=torch.rand(
-    [rank, weight.shape[1]], dtype=weight.dtype, device=self._device
+# lora_a / lora_b are random unless we explicitly ask for zero matrices
+lora_a=self.get_lora_mat(
+    [rank, weight.shape[1]],
+    is_zero=use_zero_lora_a,
+    dtype=weight.dtype,
+    device=self._device,
),
-lora_b=torch.rand(
-    [weight.shape[0], rank], dtype=weight.dtype, device=self._device
+lora_b=self.get_lora_mat(
+    [weight.shape[0], rank],
+    is_zero=use_zero_lora_b,
+    dtype=weight.dtype,
+    device=self._device,
),
)
self.set_module_lora(module_name, lora)
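A short usage sketch of the new flags follows. The import path and constructor call are assumed from how the tests use DummyLoRAManager, and the weight shapes are illustrative rather than taken from a real layer.

import torch

from tests.lora.utils import DummyLoRAManager

# Illustrative only: build a LoRA whose A matrix is all zeros, so its
# delta is zero and the wrapped layer should match its base layer.
manager = DummyLoRAManager(torch.device("cpu"))
weight = torch.rand(256, 512, dtype=torch.float16)
lora = manager.init_random_lora(
    module_name="fake_embedding",
    weight=weight,
    rank=8,
    use_zero_lora_a=True,   # A is forced to zeros
    use_zero_lora_b=False,  # B stays random
)
assert torch.all(lora.lora_a == 0)
assert torch.any(lora.lora_b != 0)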
7 changes: 7 additions & 0 deletions vllm/lora/layers/vocal_parallel_embedding.py
@@ -91,6 +91,13 @@ def set_lora(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Short circuit: if either A or B is all zeros, the LoRA delta is zero,
# so we can just call the base layer directly.
if bool(torch.all(self.lora_a_stacked == 0)) or bool(
torch.all(self.lora_b_stacked == 0)
):
return self.base_layer.forward(x)
Review comment on lines 93 to +99:

P1: Short-circuit scans full LoRA stacks every forward

The new zero check reduces over `self.lora_a_stacked` and `self.lora_b_stacked` on every call to `forward`, converting the GPU result to a Python bool. These tensors are sized max_loras × vocab_size × rank and live on the GPU; scanning them on every call adds O(max_loras · vocab) work plus a host sync even when the LoRA weights are non-zero, which is far heavier than the previous gather-based path and will noticeably slow embedding lookups for any LoRA-enabled run. Consider caching a flag when weights are loaded instead of recomputing a full reduction per forward.
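A rough, self-contained sketch of that suggestion is below. The class and flag names are hypothetical, and the layout deliberately simplifies vLLM's real stacked-adapter handling: the zero check runs once when adapter weights are written, and forward only reads a cached Python bool.

import torch
import torch.nn as nn

# Hypothetical sketch, not vLLM's actual layer: cache a "delta is zero"
# flag when LoRA weights are loaded instead of reducing over the full
# stacked tensors (and forcing a host sync) on every forward call.
class EmbeddingWithCachedZeroFlag(nn.Module):
    def __init__(self, base: nn.Embedding, max_loras: int, rank: int):
        super().__init__()
        self.base_layer = base
        self.lora_a_stacked = torch.zeros(max_loras, base.num_embeddings, rank)
        self.lora_b_stacked = torch.zeros(max_loras, base.embedding_dim, rank)
        self._lora_delta_is_zero = True  # no adapters loaded yet

    def set_lora(self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor):
        self.lora_a_stacked[index].copy_(lora_a)
        self.lora_b_stacked[index].copy_(lora_b)
        # One reduction per adapter load, not one per forward pass.
        self._lora_delta_is_zero = bool(
            torch.all(self.lora_a_stacked == 0)
            or torch.all(self.lora_b_stacked == 0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._lora_delta_is_zero:
            return self.base_layer(x)
        # Simplified: apply only adapter 0's delta; the real layer routes
        # each token to its own adapter via the punica wrapper.
        delta = self.lora_a_stacked[0][x] @ self.lora_b_stacked[0].T
        return self.base_layer(x) + delta

The new test above would still pass with such a flag, since the flag is true exactly when one of the stacked factors is all zeros.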



# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]