import torch
import torch.nn.functional as F
- from vllm.model_executor.custom_op import CustomOp
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
+ from vllm.lora import layers
+ from vllm.platforms import current_platform
+ from typing import Optional


- @CustomOp.register_oot(name='VocabParallelEmbeddingWithLoRA')
class HPUVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):

-     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-         # x need to reshaped into 2d as batch is there
-         # can be removed on moving to flat tensors
-         shape = x.shape
-         x = x.view(shape[0] * shape[1])
-
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                        1, 0)
-         embeddings_indices = torch.narrow(
-             self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))

-         indices = embeddings_indices[1]
+         # NB: Don't use torch.narrow here. torch.narrow triggers some
+         # Dynamic Shape specialization in torch.compile
+         # flatten to get num_tokens since HPU uses 2d input layout
+         # reshape indices_1, indices_0 to match shape of input
+         num_tokens = x.view(-1).shape[0]
+         indices_1 = self.punica_wrapper._embeddings_indices[
+             1][:num_tokens].view_as(x)
+         indices_0 = self.punica_wrapper._embeddings_indices[
+             0][:num_tokens].view_as(x)
+
        full_lora_a_embeddings = F.embedding(
-             x + indices,
+             x + indices_1,
            self.lora_a_stacked_2d,
        )
-         indices = embeddings_indices[0]
        full_output = self.base_layer.forward(x +
-                                               (indices * added_tokens_mask))
+                                               (indices_0 * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
@@ -37,11 +39,20 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
            full_lora_a_embeddings.shape[1],
            -1,
        )
-         self.punica_wrapper.add_lora_embedding(full_output,
-                                                full_lora_a_embeddings,
-                                                self.lora_b_stacked,
-                                                add_input=True)
-         # can be removed on moving to flat tensors
-         full_output_org = full_output_org.view(shape[0], shape[1],
-                                                full_output_org.shape[1])
+
+         lora_output: Optional[
+             torch.Tensor] = self.punica_wrapper.add_lora_embedding(
+                 full_output,
+                 full_lora_a_embeddings,
+                 self.lora_b_stacked,
+                 add_input=True)
+
+         if not current_platform.can_update_inplace():
+             full_output = lora_output
+
        return full_output.view_as(full_output_org)
+
+
+ # refer to https://github.com/vllm-project/vllm/pull/21923 for more details
+ # on why this patching is needed.
+ layers.VocabParallelEmbeddingWithLoRA = HPUVocabParallelEmbeddingWithLoRA
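
Aside (not part of the patch): a minimal standalone sketch of why the slicing above is a drop-in replacement for the removed torch.narrow call, using made-up tensor shapes rather than the real punica_wrapper state. Both expressions take the first num_tokens entries of one row of the indices table; plain slicing simply avoids the torch.narrow op that the comment in the diff flags as triggering dynamic-shape specialization under torch.compile.

import torch

# stand-in for punica_wrapper._embeddings_indices, shape [2, max_tokens]
embeddings_indices = torch.arange(24).reshape(2, 12)
x = torch.zeros(2, 3, dtype=torch.long)   # 2d HPU input layout (batch, seq)
num_tokens = x.view(-1).shape[0]          # 6 tokens after flattening

via_narrow = torch.narrow(embeddings_indices, 1, 0, num_tokens)[1]
via_slice = embeddings_indices[1][:num_tokens]

assert torch.equal(via_narrow, via_slice)
print(via_slice.view_as(x).shape)         # torch.Size([2, 3]), same layout as x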