Skip to content

Commit 52e125f

Browse files
author
jibxie
committed
Support modal model input
1 parent b71088a commit 52e125f

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

src/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class TritonPythonModel:
5252
def auto_complete_config(auto_complete_model_config):
5353
inputs = [
5454
{"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
55+
{
56+
"name": "multi_modal_data",
57+
"data_type": "TYPE_STRING",
58+
"dims": [1],
59+
"optional": True,
60+
},
5561
{
5662
"name": "stream",
5763
"data_type": "TYPE_BOOL",
@@ -385,6 +391,21 @@ async def generate(self, request):
385391
).as_numpy()[0]
386392
if isinstance(prompt, bytes):
387393
prompt = prompt.decode("utf-8")
394+
395+
multi_modal_data = pb_utils.get_input_tensor_by_name(
396+
request, "multi_modal_data"
397+
).as_numpy()[0]
398+
if isinstance(multi_modal_data, bytes):
399+
multi_modal_data = multi_modal_data.decode("utf-8")
400+
401+
if multi_modal_data is not None:
402+
# Build TextPrompt format prompt for multi modal models
403+
multi_modal_data = json.loads(multi_modal_data)
404+
prompt = {
405+
"prompt": prompt,
406+
"multi_modal_data": multi_modal_data
407+
}
408+
388409
stream = pb_utils.get_input_tensor_by_name(request, "stream")
389410
if stream:
390411
stream = stream.as_numpy()[0]

0 commit comments

Comments
 (0)