
Commit e4fa194

Initial commit to add LoRA support
Remove dependency on LoRA worker class
First working version with simple example
Fixed BS>1 case
Fix in platform.py to avoid error due to missing vllm_config
Fix No LoRA case
Fix warmup with LoRA
Minor Cleanup
Disable HPU Graphs
Clean-up. Minor fixes

Signed-off-by: Vivek <[email protected]>
1 parent 0cc8bb6 commit e4fa194

File tree

3 files changed: +349 −14 lines changed
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union, final

import torch
from vllm_gaudi.extension.ops import (dispatch_bgmv_embedding,
                                      dispatch_bgmv_linear)

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase


@final
class PunicaWrapperHPU(PunicaWrapperBase):

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Increasing max_num_batched_tokens by 3x to handle increase in
        # tensor size due to padding.
        # TODO: Need to check if this override is still required
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        x = x.view(-1, x.shape[-1])
        offset_left = 0

        for slice_idx in range(len(output_slices)):
            dispatch_bgmv_linear(
                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
            offset_left += output_slices[slice_idx]

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        raise NotImplementedError

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        raise NotImplementedError
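For orientation, below is a minimal dense reference of the per-slice LoRA update that add_lora_linear hands to dispatch_bgmv_linear. The helper name, the single-adapter shapes, and the plain-matmul formulation are illustrative assumptions rather than the actual HPU kernel signature.

# Hedged reference sketch, not the Gaudi op: one LoRA adapter applied to one
# output slice as y_slice += scale * (x @ A) @ B ("shrink" then "expand").
import torch

def lora_slice_reference(y_slice: torch.Tensor, x: torch.Tensor,
                         lora_a: torch.Tensor, lora_b: torch.Tensor,
                         scale: float) -> None:
    # x: [num_tokens, hidden]; lora_a: [hidden, rank]; lora_b: [rank, slice_width]
    # y_slice: [num_tokens, slice_width], updated in place like the fused kernel.
    y_slice += scale * (x @ lora_a) @ lora_b

add_lora_linear walks output_slices (for example the widths of a fused q/k/v projection), advancing offset_left so each column block of y receives its own (A, B) pair.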

vllm_gaudi/platform.py

Lines changed: 9 additions & 2 deletions
@@ -93,11 +93,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

-        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
+        model_config = vllm_config.model_config
+        if model_config is not None and model_config.dtype in (torch.float16,
+                                                               torch.float32):
+            logger.warning(
+                "The TPU backend currently does not support %s. "
+                "Using bfloat16 instead.", model_config.dtype)
+            model_config.dtype = torch.bfloat16
+        '''if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
-            vllm_config.model_config.dtype = torch.bfloat16
+            vllm_config.model_config.dtype = torch.bfloat16'''

         if envs.VLLM_USE_V1:
             from vllm.config import CompilationLevel
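The None-check added above matters when check_and_update_config runs before a model config has been attached to the vLLM config. A short hedged sketch of the failure mode it guards against follows; SimpleNamespace stands in for the real config classes and is hypothetical.

# Hypothetical repro of the crash the new None-check avoids.
from types import SimpleNamespace

import torch

cfg = SimpleNamespace(model_config=None)

# Old form: cfg.model_config.dtype raises AttributeError when model_config is None.
# New form, mirroring the diff:
model_config = cfg.model_config
if model_config is not None and model_config.dtype in (torch.float16,
                                                       torch.float32):
    model_config.dtype = torch.bfloat16  # fall back to a supported dtype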
