2828import json
2929from abc import abstractmethod
3030from io import BytesIO
31- from typing import Callable
31+ from typing import Callable , Dict , List , Optional
3232
3333import numpy as np
3434import triton_python_backend_utils as pb_utils
@@ -51,7 +51,7 @@ class RequestBase:
5151 def __init__ (
5252 self , request , executor_callback : Callable , output_dtype : np .dtype , logger
5353 ):
54- self .request = request
54+ self .triton_request = request
5555 self .executor_callback = executor_callback
5656 self .output_dtype = output_dtype
5757 self .logger = logger
@@ -74,20 +74,31 @@ def create_response(self, request_output, *args, **kwargs):
7474
7575class GenerateRequest (RequestBase ):
7676 def __init__ (
77- self , request , executor_callback : Callable , output_dtype : np .dtype , logger
77+ self ,
78+ request ,
79+ executor_callback : Callable ,
80+ output_dtype : np .dtype ,
81+ logger ,
82+ lora_repository : Optional [Dict [str , str ]] = None ,
83+ supported_loras : Optional [List [str ]] = None ,
7884 ):
7985 super ().__init__ (request , executor_callback , output_dtype , logger )
86+ # Attributes for generate requests
87+ if lora_repository is not None :
88+ self .lora_repository = lora_repository
89+ if supported_loras is not None :
90+ self .supported_loras = supported_loras
8091
8192 def _get_input_tensors (self ):
8293 # prompt
8394 prompt = pb_utils .get_input_tensor_by_name (
84- self .request , "text_input"
95+ self .triton_request , "text_input"
8596 ).as_numpy ()[0 ]
8697 if isinstance (prompt , bytes ):
8798 prompt = prompt .decode ("utf-8" )
8899
89100 # image
90- images = pb_utils .get_input_tensor_by_name (self .request , "image" )
101+ images = pb_utils .get_input_tensor_by_name (self .triton_request , "image" )
91102 if images :
92103 images_vllm = []
93104 for image_np in images .as_numpy ():
@@ -101,15 +112,15 @@ def _get_input_tensors(self):
101112 }
102113
103114 # stream
104- stream = pb_utils .get_input_tensor_by_name (self .request , "stream" )
115+ stream = pb_utils .get_input_tensor_by_name (self .triton_request , "stream" )
105116 if stream :
106117 stream = stream .as_numpy ()[0 ]
107118 else :
108119 stream = False
109120
110121 # prepend_input / exclude_input_in_output
111122 prepend_input = pb_utils .get_input_tensor_by_name (
112- self .request , "exclude_input_in_output"
123+ self .triton_request , "exclude_input_in_output"
113124 )
114125 if prepend_input :
115126 # When `exclude_input_in_output` is False, we want to prepend input prompt
@@ -128,12 +139,12 @@ def _get_input_tensors(self):
128139 # An alternative mechanism to receive serialized parameters as an input
129140 # tensor, because request parameters are not yet supported via BLS.
130141 sampling_parameters = pb_utils .get_input_tensor_by_name (
131- self .request , "sampling_parameters"
142+ self .triton_request , "sampling_parameters"
132143 )
133144 if sampling_parameters :
134145 parameters = sampling_parameters .as_numpy ()[0 ].decode ("utf-8" )
135146 else :
136- parameters = self .request .parameters ()
147+ parameters = self .triton_request .parameters ()
137148
138149 # additional outputs
139150 additional_outputs = {
@@ -144,7 +155,7 @@ def _get_input_tensors(self):
144155 "return_num_output_tokens" : None ,
145156 }
146157 for tensor_name in additional_outputs .keys ():
147- tensor = pb_utils .get_input_tensor_by_name (self .request , tensor_name )
158+ tensor = pb_utils .get_input_tensor_by_name (self .triton_request , tensor_name )
148159 if tensor :
149160 tensor = bool (tensor .as_numpy ()[0 ])
150161 else :
@@ -302,7 +313,7 @@ def __init__(
302313
303314 def _get_input_tensors (self ):
304315 embedding_request = pb_utils .get_input_tensor_by_name (
305- self .request , "embedding_request"
316+ self .triton_request , "embedding_request"
306317 ).as_numpy ()[0 ]
307318 embedding_request = json .loads (embedding_request .decode ("utf-8" ))
308319 # prompt
@@ -324,7 +335,7 @@ def _get_input_tensors(self):
324335 "return_num_output_tokens" : None ,
325336 }
326337 for tensor_name in additional_outputs .keys ():
327- tensor = pb_utils .get_input_tensor_by_name (self .request , tensor_name )
338+ tensor = pb_utils .get_input_tensor_by_name (self .triton_request , tensor_name )
328339 if tensor :
329340 tensor = bool (tensor .as_numpy ()[0 ])
330341 else :
0 commit comments