Commit d540834

Initial commit to add LoRA support
Remove dependency on LoRA worker class
First working version with simple example
Fixed BS>1 case
Fix in platform.py to avoid error due to missing vllm_config
Fix No LoRA case
Fix warmup with LoRA
Minor cleanup
Disable HPU Graphs
Clean-up, minor fixes

Signed-off-by: Vivek <[email protected]>
1 parent 86a8ace commit d540834

3 files changed: +338 −13 lines changed

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union, final

import torch
from vllm_gaudi.extension.ops import (dispatch_bgmv_embedding,
                                      dispatch_bgmv_linear)

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase


@final
class PunicaWrapperHPU(PunicaWrapperBase):

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Increasing max_num_batched_tokens by 3x to handle increase in
        # tensor size due to padding.
        # TODO: Need to check if this override is still required
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        x = x.view(-1, x.shape[-1])
        offset_left = 0

        for slice_idx in range(len(output_slices)):
            dispatch_bgmv_linear(
                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
            offset_left += output_slices[slice_idx]

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        raise NotImplementedError

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        raise NotImplementedError
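For readers unfamiliar with the Punica kernels: dispatch_bgmv_linear performs a batched gather matrix-vector multiply (BGMV), the per-token LoRA shrink-and-expand that add_lora_linear applies one output slice at a time. The snippet below is a minimal pure-PyTorch sketch of that math, not the vllm_gaudi HPU op; the stacking layout of lora_a/lora_b and the per-token indices argument are assumptions for illustration (the real kernel lives in vllm_gaudi/extension/ops.py).

import torch


def bgmv_linear_reference(y: torch.Tensor, x: torch.Tensor,
                          lora_a: torch.Tensor, lora_b: torch.Tensor,
                          indices: torch.Tensor, scale: float) -> None:
    """Reference LoRA BGMV: y[t] += scale * B[idx[t]] @ (A[idx[t]] @ x[t]).

    Assumed shapes: x (tokens, hidden), y (tokens, out_dim),
    lora_a (max_loras, rank, hidden), lora_b (max_loras, out_dim, rank),
    indices (tokens,) mapping each token to its adapter slot.
    """
    shrunk = torch.einsum('th,trh->tr', x, lora_a[indices])        # down-project
    expanded = torch.einsum('tr,tor->to', shrunk, lora_b[indices])  # up-project
    y.add_(expanded, alpha=scale)  # accumulate into the base output in place


# Tiny smoke test: 3 tokens, 2 adapter slots, rank 4.
tokens, hidden, out_dim, rank, max_loras = 3, 8, 6, 4, 2
x = torch.randn(tokens, hidden)
y = torch.zeros(tokens, out_dim)
lora_a = torch.randn(max_loras, rank, hidden)
lora_b = torch.randn(max_loras, out_dim, rank)
indices = torch.tensor([0, 1, 0])  # per-token adapter slot
bgmv_linear_reference(y, x, lora_a, lora_b, indices, scale=1.0)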

vllm_gaudi/platform.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if (vllm_config.model_config is not None
                 and vllm_config.model_config.dtype
-                in (torch.float16, torch.float32)):
+                in (torch.float16,)):
             logger.warning(
                 "The HPU backend currently does not support %s. "
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
