Commit a86ece5

[Bugfix][LoRA] Fix forward error and shape mismatch when using LoRA (#3153)
### What this PR does / why we need it?

Building on #3044, this PR further fixes two issues:

1. The forward error that occurred when `LogitsProcessorWithLoRA` calls `AscendLogitsProcessor.forward`. Because `LogitsProcessorWithLoRA` bypasses the MRO to invoke it, the `super().forward(...)` call inside `AscendLogitsProcessor.forward` raises an error. This PR fixes it by directly invoking `LogitsProcessor.forward(self, ...)`.
2. The shape mismatch in `add_lora_logits` in punica_npu.py. `lora_a_stacked` and `lora_b_stacked` are organized as [num_loras, 1, lora_rank, hidden_size] and [num_loras, 1, vocab_size, lora_rank] respectively, but #1583 assumed the last two dimensions in the reverse order, which causes errors in `bgmv_shrink` and `bgmv_expand`. This PR fixes it by reverting to the previous version, aligning with the implementation in punica_cpu.py in vllm.

### Dependencies

This PR depends on changes introduced by #3044 (LoRA support for the `AscendQKVParallelLinear` and `AscendMergedQKVParallelLinear` layers).

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

The LoRA-related tests, e.g. test_ilama_lora.py and test_ilama_lora_tp2.py, use ilama-3.2-1B. This model is loaded as `TransformersForCausalLM`, whose `embedding_modules` attribute lacks `lm_head`, whereas `LlamaForCausalLM` and most other models include both `embed_tokens` and `lm_head` in `embedding_modules`. This attribute contributes to `supported_lora_modules` when using LoRA in vllm. Therefore, without `lm_head` in `embedding_modules`, the current tests using ilama-3.2-1B cannot catch the above errors, since replacing `lm_head` with `LogitsProcessorWithLoRA` is skipped. Simply using Meta-Llama-3.1-8B-Instruct reproduces the errors and verifies that these fixes work (a reproduction sketch follows this description). Beyond that, more comprehensive LoRA tests are still needed.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@f225ea7

Signed-off-by: Zetong Li <[email protected]>
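A reproduction along the lines described above could look like the following offline-inference sketch; the adapter name, path, and `max_lora_rank` value are placeholders, not values taken from this PR.

```python
# Hypothetical reproduction sketch: run Meta-Llama-3.1-8B-Instruct with LoRA
# enabled; the adapter path/name and max_lora_rank are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_lora_rank=16,  # placeholder; should match the adapter's rank
)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # Before this fix, a LoRA request on a model whose supported_lora_modules
    # include lm_head could hit the super().forward error and the bgmv shape
    # mismatch described above.
    lora_request=LoRARequest("demo-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```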
1 parent: 3d21ed9

File tree

2 files changed: +7, -17 lines


vllm_ascend/lora/punica_npu.py

Lines changed: 3 additions & 14 deletions
@@ -341,27 +341,16 @@ def add_lora_logits(self,
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-
-        if lora_a_stacked.dim() == 2:
-            lora_a_stacked = lora_a_stacked.unsqueeze(0)
-        if lora_b_stacked.dim() == 2:
-            lora_b_stacked = lora_b_stacked.unsqueeze(0)
-
-        r = lora_a_stacked.size(-1)
+        r = lora_b_stacked.size(-1)
 
         if buffer is None:
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
 
         indices = self.sampler_indices
-        if indices.max() >= lora_a_stacked.size(0):
-            indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
-
-        lora_a_reshaped = lora_a_stacked.transpose(1, 2)
-        lora_b_reshaped = lora_b_stacked.transpose(1, 2)
 
-        bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
-        bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
+        bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
+        bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
 
         y = y.view_as(y_org)
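As a shape-level sanity check for this revert, here is a rough pure-PyTorch illustration of the layout `add_lora_logits` now assumes for the stacked weights. It mimics the per-token shrink/expand arithmetic only; it is not the NPU `bgmv_shrink`/`bgmv_expand` kernel code, and all sizes are made-up examples.

```python
# Illustration of the assumed shapes (sizes are arbitrary examples):
#   lora_a_stacked: [num_loras, 1, lora_rank, hidden_size]
#   lora_b_stacked: [num_loras, 1, vocab_size, lora_rank]
import torch

num_loras, lora_rank, hidden_size, vocab_size, num_tokens = 2, 8, 64, 128, 4

lora_a = torch.randn(num_loras, 1, lora_rank, hidden_size)
lora_b = torch.randn(num_loras, 1, vocab_size, lora_rank)
x = torch.randn(num_tokens, hidden_size)       # flattened hidden states
y = torch.zeros(num_tokens, vocab_size)        # logits accumulator
indices = torch.tensor([0, 1, 0, 1])           # per-token LoRA id (sampler_indices)
scale = 1.0

r = lora_b.size(-1)                            # lora_rank, as in the fixed code
buffer = torch.zeros(num_tokens, r)

# shrink: buffer[i] = scale * A[idx] @ x[i]  ->  [num_tokens, lora_rank]
# expand: y[i]    += B[idx] @ buffer[i]      ->  [num_tokens, vocab_size]
for i, idx in enumerate(indices):
    buffer[i] = scale * (lora_a[idx, 0] @ x[i])
    y[i] += lora_b[idx, 0] @ buffer[i]

print(buffer.shape, y.shape)  # torch.Size([4, 8]) torch.Size([4, 128])
```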

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 4 additions & 3 deletions
@@ -262,6 +262,7 @@ def forward(
         sampling_metadata=None,  # type: ignore
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
-        return super().forward(lm_head,
-                               hidden_states,
-                               embedding_bias=embedding_bias)
+        return LogitsProcessor.forward(self,
+                                       lm_head,
+                                       hidden_states,
+                                       embedding_bias=embedding_bias)
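For context on this change, below is a minimal, self-contained sketch of why the replaced `super().forward(...)` call failed: vLLM's LoRA wrapper invokes the patched layer's `forward` as an unbound function and passes the wrapper itself as `self`, so the zero-argument `super()` inside the subclass rejects that `self`. The class names are illustrative stand-ins, not the real `LogitsProcessor`, `AscendLogitsProcessor`, or `LogitsProcessorWithLoRA`.

```python
# Minimal sketch of the MRO issue; illustrative class names only.
class Base:                  # stands in for LogitsProcessor
    def forward(self, x):
        return x + 1

class AscendChild(Base):     # stands in for AscendLogitsProcessor (before the fix)
    def forward(self, x):
        # Zero-arg super() expands to super(AscendChild, self) and requires
        # isinstance(self, AscendChild).
        return super().forward(x)

class FixedChild(Base):      # stands in for AscendLogitsProcessor (after the fix)
    def forward(self, x):
        # Explicit base-class call works for any `self`.
        return Base.forward(self, x)

class LoRAWrapper:           # stands in for LogitsProcessorWithLoRA
    def __init__(self, base_layer):
        self.base_layer = base_layer

    def forward(self, x):
        # The wrapper bypasses the MRO: it calls the patched layer's forward
        # as an unbound function with the wrapper itself as `self`.
        return type(self.base_layer).forward(self, x)

try:
    LoRAWrapper(AscendChild()).forward(1)
except TypeError as exc:
    print(exc)  # super(type, obj): obj must be an instance or subtype of type

print(LoRAWrapper(FixedChild()).forward(1))  # 2
```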
