Commit b5c1703
Initial commit to add LoRA support
- Remove dependency on LoRA worker class
- First working version with simple example
- Fixed BS>1 case
- Fix in platform.py to avoid error due to missing vllm_config
- Fix No LoRA case
- Fix warmup with LoRA
1 parent 0cc8bb6 commit b5c1703
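
The commit message mentions a first working version with a simple example. As a minimal sketch of how this LoRA support might be exercised, assuming vLLM's standard LoRA request API (the model name and adapter path below are placeholders, not taken from this commit):

# Minimal sketch; base model and adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,                  # enable the LoRA code path
    max_loras=1,
    max_lora_rank=8,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# lora_int_id must be a positive integer that is unique per adapter.
lora_request = LoRARequest("my-adapter", 1, "/path/to/lora_adapter")

outputs = llm.generate(
    ["Give me a short introduction to LoRA."],
    sampling_params,
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)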

4 files changed: +356 −11 lines

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Optional, Union, final

import torch
from vllm_gaudi.extension.ops import (dispatch_bgmv_embedding,
                                      dispatch_bgmv_linear)

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.utils import convert_mapping

if TYPE_CHECKING:
    # avoid circular import
    from vllm.lora.layers import LoRAMapping
    from vllm.lora.models import LongContextLoRAContext


@final
class PunicaWrapperHPU(PunicaWrapperBase):

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Increase max_num_batched_tokens by 3x to handle the increase in
        # tensor size due to padding.
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            indices_len,
        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
                            extra_vocab_size, self.device)
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        self.indices_len[:] = indices_len

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        x = x.view(-1, x.shape[-1])
        offset_left = 0

        for slice_idx in range(len(output_slices)):
            dispatch_bgmv_linear(
                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
            offset_left += output_slices[slice_idx]

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        raise NotImplementedError

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        raise NotImplementedError
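
For reference, a plain-PyTorch sketch of the shrink/expand computation that add_lora_linear delegates to dispatch_bgmv_linear, under the assumption that the HPU op applies, per token, the LoRA adapter selected by the wrapper's token-to-LoRA index mapping (y += scale * (x @ A^T) @ B^T accumulated into the given output slice). This is illustrative only, not the Gaudi kernel:

# Illustrative reference only (assumed semantics of dispatch_bgmv_linear);
# the real op runs as a fused kernel on Gaudi hardware.
import torch

def bgmv_linear_ref(y, x, lora_a, lora_b, indices, scale):
    """y: (num_tokens, out_dim) output slice, updated in place.
    x: (num_tokens, hidden) input activations.
    lora_a: (num_loras, rank, hidden) stacked A matrices.
    lora_b: (num_loras, out_dim, rank) stacked B matrices.
    indices: (num_tokens,) LoRA id per token, -1 meaning no adapter.
    """
    for i in range(x.shape[0]):
        lora_id = int(indices[i])
        if lora_id < 0:
            continue  # token has no active adapter
        shrink = x[i] @ lora_a[lora_id].t()    # project down to rank
        expand = shrink @ lora_b[lora_id].t()  # project back to out_dim
        y[i] += scale * expand
    return y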

vllm_gaudi/ops/hpu_lora.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from vllm.model_executor.custom_op import CustomOp
+from vllm.lora import layers
 from vllm.lora.layers import VocabParallelEmbeddingWithLoRA

vllm_gaudi/platform.py

Lines changed: 9 additions & 2 deletions
@@ -93,11 +93,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
-        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
+        model_config = vllm_config.model_config
+        if model_config is not None and model_config.dtype in (torch.float16,
+                                                               torch.float32):
+            logger.warning(
+                "The TPU backend currently does not support %s. "
+                "Using bfloat16 instead.", model_config.dtype)
+            model_config.dtype = torch.bfloat16
+        '''if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)
-            vllm_config.model_config.dtype = torch.bfloat16
+            vllm_config.model_config.dtype = torch.bfloat16'''
 
         if envs.VLLM_USE_V1:
             from vllm.config import CompilationLevel
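
The platform.py change guards against a missing model_config and coerces unsupported dtypes to bfloat16. A minimal sketch of the resulting behavior, using stand-in objects rather than the real VllmConfig classes:

# Stand-in objects for illustration; not the real vLLM config types.
import torch
from types import SimpleNamespace

def coerce_dtype(vllm_config):
    model_config = vllm_config.model_config
    if model_config is not None and model_config.dtype in (torch.float16,
                                                           torch.float32):
        # The HPU path falls back to bfloat16 for these dtypes.
        model_config.dtype = torch.bfloat16

cfg = SimpleNamespace(model_config=SimpleNamespace(dtype=torch.float16))
coerce_dtype(cfg)
assert cfg.model_config.dtype == torch.bfloat16

cfg_no_model = SimpleNamespace(model_config=None)
coerce_dtype(cfg_no_model)  # no AttributeError thanks to the None guard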
