
Commit e556bb5

Update hpu_lora to use patching
Signed-off-by: Vivek <[email protected]>
1 parent 6e30fbe commit e556bb5

File tree

1 file changed: +32 additions, −21 deletions


vllm_gaudi/ops/hpu_lora.py

Lines changed: 32 additions & 21 deletions
@@ -1,31 +1,33 @@
 import torch
 import torch.nn.functional as F
-from vllm.model_executor.custom_op import CustomOp
 from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
+from vllm.lora import layers
+from vllm.platforms import current_platform
+from typing import Optional
 
 
-@CustomOp.register_oot(name='VocabParallelEmbeddingWithLoRA')
 class HPUVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
 
-    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-        # x need to reshaped into 2d as batch is there
-        # can be removed on moving to flat tensors
-        shape = x.shape
-        x = x.view(shape[0] * shape[1])
-
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
 
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        # flatten to get num_tokens since HPU uses 2d input layout
+        # reshape indices_1, indices_0 to match shape of input
+        num_tokens = x.view(-1).shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[
+            1][:num_tokens].view_as(x)
+        indices_0 = self.punica_wrapper._embeddings_indices[
+            0][:num_tokens].view_as(x)
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
 
         full_output_org = full_output
         if full_output.ndim == 3:
@@ -37,11 +39,20 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
                 full_lora_a_embeddings.shape[1],
                 -1,
             )
-        self.punica_wrapper.add_lora_embedding(full_output,
-                                               full_lora_a_embeddings,
-                                               self.lora_b_stacked,
-                                               add_input=True)
-        # can be removed on moving to flat tensors
-        full_output_org = full_output_org.view(shape[0], shape[1],
-                                               full_output_org.shape[1])
+
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
+                full_output,
+                full_lora_a_embeddings,
+                self.lora_b_stacked,
+                add_input=True)
+
+        if not current_platform.can_update_inplace():
+            full_output = lora_output
+
         return full_output.view_as(full_output_org)
+
+
+# refer to https://github.com/vllm-project/vllm/pull/21923 for more details
+# on why this patching is needed.
+layers.VocabParallelEmbeddingWithLoRA = HPUVocabParallelEmbeddingWithLoRA
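Note on the indexing change: the commit replaces torch.narrow on the Punica _embeddings_indices buffer with plain slicing plus view_as(x), because torch.narrow was triggering dynamic-shape specialization under torch.compile. Below is a minimal standalone sketch of that slicing pattern only, using a made-up stand-in tensor and hypothetical shapes rather than the real punica_wrapper:

import torch

# Hypothetical stand-in for punica_wrapper._embeddings_indices:
# two rows of per-token offsets, sized to a maximum token count.
embeddings_indices = torch.arange(2 * 16).view(2, 16)

# HPU keeps token ids in a 2D [batch, seq_len] layout.
x = torch.zeros(2, 4, dtype=torch.long)

# Flatten only to count tokens, then slice each index row and reshape it
# back to x's 2D shape -- the pattern the new forward() uses in place of
# torch.narrow.
num_tokens = x.view(-1).shape[0]
indices_1 = embeddings_indices[1][:num_tokens].view_as(x)
indices_0 = embeddings_indices[0][:num_tokens].view_as(x)

assert indices_1.shape == x.shape and indices_0.shape == x.shape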

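Note on the patching: the trailing assignment rebinds the VocabParallelEmbeddingWithLoRA attribute on the vllm.lora.layers module, so code that later resolves the class through that module picks up the HPU subclass (see the linked PR for why this replaces the earlier CustomOp.register_oot approach). A minimal sketch of the idiom with stand-in classes, not the actual vLLM types:

import types

class VocabParallelEmbeddingWithLoRA:          # stand-in for the upstream layer
    def forward(self, x):
        return "upstream"

class HPUVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    def forward(self, x):                      # HPU-specific override
        return "hpu"

# Stand-in for the vllm.lora.layers module.
layers = types.SimpleNamespace(
    VocabParallelEmbeddingWithLoRA=VocabParallelEmbeddingWithLoRA)

# The patch: later lookups through the module now resolve to the HPU class.
layers.VocabParallelEmbeddingWithLoRA = HPUVocabParallelEmbeddingWithLoRA

assert layers.VocabParallelEmbeddingWithLoRA().forward(None) == "hpu"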