Skip to content

Commit 23b5afa

Browse files
committed
Fix LoRA
1 parent 4538344 commit 23b5afa

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

src/model.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,6 @@ def _setup_lora(self):
352352
lora_repository: Dict[str, str] = json.load(lora_file)
353353
self.lora_repository = lora_repository
354354
self.supported_loras: List[str] = list(self.lora_repository.keys())
355-
self.supported_loras_len = len(self.supported_loras)
356355
self.enable_lora = True
357356
except FileNotFoundError:
358357
raise FileNotFoundError(
@@ -461,9 +460,22 @@ async def _infer(self, request):
461460
try:
462461
request_task_name = self._validate_request_task_name(request)
463462
if request_task_name == "generate":
464-
request = GenerateRequest(
465-
request, self._llm_engine.generate, self.output_dtype, self.logger
466-
)
463+
if self.enable_lora:
464+
request = GenerateRequest(
465+
request,
466+
self._llm_engine.generate,
467+
self.output_dtype,
468+
self.logger,
469+
self.lora_repository,
470+
self.supported_loras,
471+
)
472+
else:
473+
request = GenerateRequest(
474+
request,
475+
self._llm_engine.generate,
476+
self.output_dtype,
477+
self.logger,
478+
)
467479
elif request_task_name == "embed":
468480
request = EmbedRequest(
469481
request, self._llm_engine.encode, self.output_dtype, self.logger

src/utils/request.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import json
2929
from abc import abstractmethod
3030
from io import BytesIO
31-
from typing import Callable
31+
from typing import Callable, Dict, List, Optional
3232

3333
import numpy as np
3434
import triton_python_backend_utils as pb_utils
@@ -51,7 +51,7 @@ class RequestBase:
5151
def __init__(
5252
self, request, executor_callback: Callable, output_dtype: np.dtype, logger
5353
):
54-
self.request = request
54+
self.triton_request = request
5555
self.executor_callback = executor_callback
5656
self.output_dtype = output_dtype
5757
self.logger = logger
@@ -74,20 +74,31 @@ def create_response(self, request_output, *args, **kwargs):
7474

7575
class GenerateRequest(RequestBase):
7676
def __init__(
77-
self, request, executor_callback: Callable, output_dtype: np.dtype, logger
77+
self,
78+
request,
79+
executor_callback: Callable,
80+
output_dtype: np.dtype,
81+
logger,
82+
lora_repository: Optional[Dict[str, str]] = None,
83+
supported_loras: Optional[List[str]] = None,
7884
):
7985
super().__init__(request, executor_callback, output_dtype, logger)
86+
# Attributes for generate requests
87+
if lora_repository is not None:
88+
self.lora_repository = lora_repository
89+
if supported_loras is not None:
90+
self.supported_loras = supported_loras
8091

8192
def _get_input_tensors(self):
8293
# prompt
8394
prompt = pb_utils.get_input_tensor_by_name(
84-
self.request, "text_input"
95+
self.triton_request, "text_input"
8596
).as_numpy()[0]
8697
if isinstance(prompt, bytes):
8798
prompt = prompt.decode("utf-8")
8899

89100
# image
90-
images = pb_utils.get_input_tensor_by_name(self.request, "image")
101+
images = pb_utils.get_input_tensor_by_name(self.triton_request, "image")
91102
if images:
92103
images_vllm = []
93104
for image_np in images.as_numpy():
@@ -101,15 +112,15 @@ def _get_input_tensors(self):
101112
}
102113

103114
# stream
104-
stream = pb_utils.get_input_tensor_by_name(self.request, "stream")
115+
stream = pb_utils.get_input_tensor_by_name(self.triton_request, "stream")
105116
if stream:
106117
stream = stream.as_numpy()[0]
107118
else:
108119
stream = False
109120

110121
# prepend_input / exclude_input_in_output
111122
prepend_input = pb_utils.get_input_tensor_by_name(
112-
self.request, "exclude_input_in_output"
123+
self.triton_request, "exclude_input_in_output"
113124
)
114125
if prepend_input:
115126
# When `exclude_input_in_output` is False, we want to prepend input prompt
@@ -128,12 +139,12 @@ def _get_input_tensors(self):
128139
# An alternative mechanism to receive serialized parameters as an input
129140
# tensor, because request parameters are not yet supported via BLS.
130141
sampling_parameters = pb_utils.get_input_tensor_by_name(
131-
self.request, "sampling_parameters"
142+
self.triton_request, "sampling_parameters"
132143
)
133144
if sampling_parameters:
134145
parameters = sampling_parameters.as_numpy()[0].decode("utf-8")
135146
else:
136-
parameters = self.request.parameters()
147+
parameters = self.triton_request.parameters()
137148

138149
# additional outputs
139150
additional_outputs = {
@@ -144,7 +155,7 @@ def _get_input_tensors(self):
144155
"return_num_output_tokens": None,
145156
}
146157
for tensor_name in additional_outputs.keys():
147-
tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name)
158+
tensor = pb_utils.get_input_tensor_by_name(self.triton_request, tensor_name)
148159
if tensor:
149160
tensor = bool(tensor.as_numpy()[0])
150161
else:
@@ -302,7 +313,7 @@ def __init__(
302313

303314
def _get_input_tensors(self):
304315
embedding_request = pb_utils.get_input_tensor_by_name(
305-
self.request, "embedding_request"
316+
self.triton_request, "embedding_request"
306317
).as_numpy()[0]
307318
embedding_request = json.loads(embedding_request.decode("utf-8"))
308319
# prompt
@@ -324,7 +335,7 @@ def _get_input_tensors(self):
324335
"return_num_output_tokens": None,
325336
}
326337
for tensor_name in additional_outputs.keys():
327-
tensor = pb_utils.get_input_tensor_by_name(self.request, tensor_name)
338+
tensor = pb_utils.get_input_tensor_by_name(self.triton_request, tensor_name)
328339
if tensor:
329340
tensor = bool(tensor.as_numpy()[0])
330341
else:

0 commit comments

Comments
 (0)