-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions

     _create_trtllm_inference_request,
     _create_vllm_inference_request,
     _get_output,
+    _get_vllm_lora_names,
     _validate_triton_responses_non_streaming,
 )
 from schemas.openai import (
@@ -70,6 +71,8 @@ class TritonModelMetadata:
     model: tritonserver.Model
     # Tokenizers used for chat templates
     tokenizer: Optional[Any]
+    # LoRA names supported by the backend
+    lora_names: Optional[List[str]]
     # Time that model was loaded by Triton
     create_time: int
     # Conversion format between OpenAI and Triton requests
@@ -78,13 +81,18 @@ class TritonModelMetadata:
 
 class TritonLLMEngine(LLMEngine):
     def __init__(
-        self, server: tritonserver.Server, tokenizer: str, backend: Optional[str] = None
+        self,
+        server: tritonserver.Server,
+        tokenizer: str,
+        backend: Optional[str] = None,
+        lora_separator: Optional[str] = None,
     ):
         # Assume an already configured and started server
         self.server = server
         self.tokenizer = self._get_tokenizer(tokenizer)
         # TODO: Reconsider name of "backend" vs. something like "request_format"
         self.backend = backend
+        self.lora_separator = lora_separator
 
         # NOTE: Creation time and model metadata will be static at startup for
         # now, and won't account for dynamically loading/unloading models.
@@ -100,22 +108,35 @@ def metrics(self) -> str:
     def models(self) -> List[Model]:
         models = []
         for metadata in self.model_metadata.values():
-            models.append(
-                Model(
-                    id=metadata.name,
-                    created=metadata.create_time,
-                    object=ObjectType.model,
-                    owned_by="Triton Inference Server",
-                ),
-            )
+            model_names = [metadata.name]
+            if (
+                self.lora_separator is not None
+                and len(self.lora_separator) > 0
+                and metadata.lora_names is not None
+            ):
+                for lora_name in metadata.lora_names:
+                    model_names.append(
+                        f"{metadata.name}{self.lora_separator}{lora_name}"
+                    )
+
+            for model_name in model_names:
+                models.append(
+                    Model(
+                        id=model_name,
+                        created=metadata.create_time,
+                        object=ObjectType.model,
+                        owned_by="Triton Inference Server",
+                    ),
+                )
 
         return models
 
     async def chat(
         self, request: CreateChatCompletionRequest
     ) -> CreateChatCompletionResponse | AsyncIterator[str]:
-        metadata = self.model_metadata.get(request.model)
-        self._validate_chat_request(request, metadata)
+        model_name, lora_name = self._get_model_and_lora_name(request.model)
+        metadata = self.model_metadata.get(model_name)
+        self._validate_chat_request(request, metadata, lora_name)
 
         conversation = [
             message.model_dump(exclude_none=True) for message in request.messages
@@ -130,7 +151,7 @@ async def chat(
 
         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
-            metadata.request_converter(metadata.model, prompt, request)
+            metadata.request_converter(metadata.model, prompt, request, lora_name)
         )
 
         # Prepare and send responses back to client in OpenAI format
@@ -174,20 +195,23 @@ async def completion(
         self, request: CreateCompletionRequest
     ) -> CreateCompletionResponse | AsyncIterator[str]:
         # Validate request and convert to Triton format
-        metadata = self.model_metadata.get(request.model)
-        self._validate_completion_request(request, metadata)
+        model_name, lora_name = self._get_model_and_lora_name(request.model)
+        metadata = self.model_metadata.get(model_name)
+        self._validate_completion_request(request, metadata, lora_name)
 
         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
-            metadata.request_converter(metadata.model, request.prompt, request)
+            metadata.request_converter(
+                metadata.model, request.prompt, request, lora_name
+            )
         )
 
         # Prepare and send responses back to client in OpenAI format
         request_id = f"cmpl-{uuid.uuid1()}"
         created = int(time.time())
         if request.stream:
             return self._streaming_completion_iterator(
-                request_id, created, metadata.name, responses
+                request_id, created, request.model, responses
             )
 
         # Response validation with decoupled models in mind
@@ -208,7 +232,7 @@ async def completion(
             system_fingerprint=None,
             object=ObjectType.text_completion,
             created=created,
-            model=metadata.name,
+            model=request.model,
         )
 
         # TODO: This behavior should be tested further
@@ -234,6 +258,16 @@ def _determine_request_converter(self, backend: str):
         # an ensemble, a python or BLS model, a TRT-LLM backend model, etc.
         return _create_trtllm_inference_request
 
+    def _get_model_and_lora_name(self, request_model_name: str):
+        if self.lora_separator is None or len(self.lora_separator) == 0:
+            return request_model_name, None
+
+        names = request_model_name.split(self.lora_separator)
+        if len(names) != 2:
+            return request_model_name, None
+
+        return names[0], names[1]
+
     def _get_tokenizer(self, tokenizer_name: str):
         tokenizer = None
         if tokenizer_name:
@@ -254,11 +288,18 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]:
                 backend = "ensemble"
             print(f"Found model: {name=}, {backend=}")
 
+            lora_names = None
+            if self.backend == "vllm" or backend == "vllm":
+                lora_names = _get_vllm_lora_names(
+                    self.server.options.model_repository, name, model.version
+                )
+
             metadata = TritonModelMetadata(
                 name=name,
                 backend=backend,
                 model=model,
                 tokenizer=self.tokenizer,
+                lora_names=lora_names,
                 create_time=self.create_time,
                 request_converter=self._determine_request_converter(backend),
             )
@@ -343,7 +384,10 @@ async def _streaming_chat_iterator(
         yield "data: [DONE]\n\n"
 
     def _validate_chat_request(
-        self, request: CreateChatCompletionRequest, metadata: TritonModelMetadata
+        self,
+        request: CreateChatCompletionRequest,
+        metadata: TritonModelMetadata,
+        lora_name: str | None,
     ):
         """
         Validates a chat request to align with currently supported features.
@@ -362,6 +406,13 @@ def _validate_chat_request(
         if not metadata.request_converter:
             raise Exception(f"Unknown request format for model: {request.model}")
 
+        if (
+            metadata.lora_names is not None
+            and lora_name is not None
+            and lora_name not in metadata.lora_names
+        ):
+            raise Exception(f"Unknown LoRA: {lora_name}; for model: {request.model}")
+
         # Reject unsupported features if requested
         if request.n and request.n > 1:
             raise Exception(
@@ -396,7 +447,10 @@ async def _streaming_completion_iterator(
         yield "data: [DONE]\n\n"
 
     def _validate_completion_request(
-        self, request: CreateCompletionRequest, metadata: TritonModelMetadata
+        self,
+        request: CreateCompletionRequest,
+        metadata: TritonModelMetadata,
+        lora_name: str | None,
     ):
         """
         Validates a completions request to align with currently supported features.
@@ -411,6 +465,13 @@ def _validate_completion_request(
         if not metadata.request_converter:
             raise Exception(f"Unknown request format for model: {request.model}")
 
+        if (
+            metadata.lora_names is not None
+            and lora_name is not None
+            and lora_name not in metadata.lora_names
+        ):
+            raise Exception(f"Unknown LoRA: {lora_name}; for model: {request.model}")
+
         # Reject unsupported features if requested
        if request.suffix is not None:
             raise Exception("suffix is not currently supported")