Commit 5eb09ce

feat: Add multi-LoRA support to OpenAI frontend (#8038)
1 parent 27f0410 commit 5eb09ce

6 files changed (+576, -28 lines)

python/openai/README.md

Lines changed: 63 additions & 0 deletions
@@ -230,6 +230,69 @@ pip install -r requirements-test.txt
 pytest -v tests/
 ```

+### LoRA Adapters
+
+If the command line argument `--lora-separator=<separator_string>` is provided
+when starting the OpenAI Frontend, a vLLM LoRA adapter listed in
+`multi_lora.json` can be selected on an inference request by appending the
+LoRA name to the model name, separated by the LoRA separator, i.e. in
+`<model_name><separator_string><lora_name>` format.
+
+<details>
+<summary>For example</summary>
+
+```bash
+# start server with model named gemma-2b
+python3 openai_frontend/main.py --lora-separator=_lora_ ...
+
+# inference without LoRA
+curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json' -d '{
+  "model": "gemma-2b",
+  "temperature": 0,
+  "prompt": "When was the wheel invented?"
+}'
+{
+  ...
+  "choices":[{..."text":"\n\nThe wheel was invented by the Sumerians in Mesopotamia around 350"}],
+  ...
+}
+
+# inference with LoRA named doll
+curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json' -d '{
+  "model": "gemma-2b_lora_doll",
+  "temperature": 0,
+  "prompt": "When was the wheel invented?"
+}'
+{
+  ...
+  "choices":[{..."text":"\n\nThe wheel was invented in Mesopotamia around 3500 BC.\n\n"}],
+  ...
+}
+
+# inference with LoRA named sheep
+curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json' -d '{
+  "model": "gemma-2b_lora_sheep",
+  "temperature": 0,
+  "prompt": "When was the wheel invented?"
+}'
+{
+  ...
+  "choices":[{..."text":"\n\nThe wheel was invented around 3000 BC in Mesopotamia.\n\n"}],
+  ...
+}
+```
+
+</details>
+
+When listing or retrieving model(s), a model id in the same
+`<model_name><separator_string><lora_name>` format is returned for each LoRA
+adapter listed in `multi_lora.json`. Note: LoRA names are only enumerated for
+locally stored models; inference requests are not limited to them.
+
+See the
+[vLLM documentation](https://github.com/triton-inference-server/vllm_backend/blob/main/docs/llama_multi_lora_tutorial.md)
+on how to serve a model with LoRA adapters.
+
 ## TensorRT-LLM

 0. Prepare your model repository for a TensorRT-LLM model, build the engine, etc. You can try any of the following options:
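
As a quick client-side illustration of the naming scheme added in this README change, here is a minimal Python sketch. It assumes the setup from the README example above: the frontend running at `http://localhost:9000` with `--lora-separator=_lora_`, a model named `gemma-2b`, and a LoRA adapter named `doll` listed in `multi_lora.json`; the standard OpenAI-style `/v1/models` response shape (a `data` list of objects with `id` fields) is also assumed.

```python
# Sketch: list model ids (base model plus one id per LoRA adapter) and run a
# completion against a LoRA-suffixed model id. Server address, separator,
# model, and adapter names are taken from the README example above.
import requests

BASE_URL = "http://localhost:9000/v1"
LORA_SEPARATOR = "_lora_"  # must match --lora-separator passed to the frontend

# Expect ids such as "gemma-2b" and "gemma-2b_lora_doll" in the listing.
model_ids = [m["id"] for m in requests.get(f"{BASE_URL}/models").json()["data"]]
print(model_ids)

# Select the "doll" adapter by appending it to the base model name.
lora_model_id = f"gemma-2b{LORA_SEPARATOR}doll"
response = requests.post(
    f"{BASE_URL}/completions",
    json={
        "model": lora_model_id,
        "temperature": 0,
        "prompt": "When was the wheel invented?",
    },
)
print(response.json()["choices"][0]["text"])
```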

python/openai/openai_frontend/engine/triton_engine.py

Lines changed: 81 additions & 20 deletions
@@ -1,4 +1,4 @@
-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -39,6 +39,7 @@
     _create_trtllm_inference_request,
     _create_vllm_inference_request,
     _get_output,
+    _get_vllm_lora_names,
     _validate_triton_responses_non_streaming,
 )
 from schemas.openai import (
@@ -70,6 +71,8 @@ class TritonModelMetadata:
     model: tritonserver.Model
     # Tokenizers used for chat templates
     tokenizer: Optional[Any]
+    # LoRA names supported by the backend
+    lora_names: Optional[List[str]]
     # Time that model was loaded by Triton
     create_time: int
     # Conversion format between OpenAI and Triton requests
@@ -78,13 +81,18 @@ class TritonModelMetadata:

 class TritonLLMEngine(LLMEngine):
     def __init__(
-        self, server: tritonserver.Server, tokenizer: str, backend: Optional[str] = None
+        self,
+        server: tritonserver.Server,
+        tokenizer: str,
+        backend: Optional[str] = None,
+        lora_separator: Optional[str] = None,
     ):
         # Assume an already configured and started server
         self.server = server
         self.tokenizer = self._get_tokenizer(tokenizer)
         # TODO: Reconsider name of "backend" vs. something like "request_format"
         self.backend = backend
+        self.lora_separator = lora_separator

         # NOTE: Creation time and model metadata will be static at startup for
         # now, and won't account for dynamically loading/unloading models.
@@ -100,22 +108,35 @@ def metrics(self) -> str:
     def models(self) -> List[Model]:
         models = []
         for metadata in self.model_metadata.values():
-            models.append(
-                Model(
-                    id=metadata.name,
-                    created=metadata.create_time,
-                    object=ObjectType.model,
-                    owned_by="Triton Inference Server",
-                ),
-            )
+            model_names = [metadata.name]
+            if (
+                self.lora_separator is not None
+                and len(self.lora_separator) > 0
+                and metadata.lora_names is not None
+            ):
+                for lora_name in metadata.lora_names:
+                    model_names.append(
+                        f"{metadata.name}{self.lora_separator}{lora_name}"
+                    )
+
+            for model_name in model_names:
+                models.append(
+                    Model(
+                        id=model_name,
+                        created=metadata.create_time,
+                        object=ObjectType.model,
+                        owned_by="Triton Inference Server",
+                    ),
+                )

         return models

     async def chat(
         self, request: CreateChatCompletionRequest
     ) -> CreateChatCompletionResponse | AsyncIterator[str]:
-        metadata = self.model_metadata.get(request.model)
-        self._validate_chat_request(request, metadata)
+        model_name, lora_name = self._get_model_and_lora_name(request.model)
+        metadata = self.model_metadata.get(model_name)
+        self._validate_chat_request(request, metadata, lora_name)

         conversation = [
             message.model_dump(exclude_none=True) for message in request.messages
@@ -130,7 +151,7 @@ async def chat(

         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
-            metadata.request_converter(metadata.model, prompt, request)
+            metadata.request_converter(metadata.model, prompt, request, lora_name)
         )

         # Prepare and send responses back to client in OpenAI format
@@ -174,20 +195,23 @@ async def completion(
         self, request: CreateCompletionRequest
     ) -> CreateCompletionResponse | AsyncIterator[str]:
         # Validate request and convert to Triton format
-        metadata = self.model_metadata.get(request.model)
-        self._validate_completion_request(request, metadata)
+        model_name, lora_name = self._get_model_and_lora_name(request.model)
+        metadata = self.model_metadata.get(model_name)
+        self._validate_completion_request(request, metadata, lora_name)

         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
-            metadata.request_converter(metadata.model, request.prompt, request)
+            metadata.request_converter(
+                metadata.model, request.prompt, request, lora_name
+            )
         )

         # Prepare and send responses back to client in OpenAI format
         request_id = f"cmpl-{uuid.uuid1()}"
         created = int(time.time())
         if request.stream:
             return self._streaming_completion_iterator(
-                request_id, created, metadata.name, responses
+                request_id, created, request.model, responses
             )

         # Response validation with decoupled models in mind
@@ -208,7 +232,7 @@ async def completion(
             system_fingerprint=None,
             object=ObjectType.text_completion,
             created=created,
-            model=metadata.name,
+            model=request.model,
         )

         # TODO: This behavior should be tested further
@@ -234,6 +258,16 @@ def _determine_request_converter(self, backend: str):
         # an ensemble, a python or BLS model, a TRT-LLM backend model, etc.
         return _create_trtllm_inference_request

+    def _get_model_and_lora_name(self, request_model_name: str):
+        if self.lora_separator is None or len(self.lora_separator) == 0:
+            return request_model_name, None
+
+        names = request_model_name.split(self.lora_separator)
+        if len(names) != 2:
+            return request_model_name, None
+
+        return names[0], names[1]
+
     def _get_tokenizer(self, tokenizer_name: str):
         tokenizer = None
         if tokenizer_name:
@@ -254,11 +288,18 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]:
                 backend = "ensemble"
             print(f"Found model: {name=}, {backend=}")

+            lora_names = None
+            if self.backend == "vllm" or backend == "vllm":
+                lora_names = _get_vllm_lora_names(
+                    self.server.options.model_repository, name, model.version
+                )
+
             metadata = TritonModelMetadata(
                 name=name,
                 backend=backend,
                 model=model,
                 tokenizer=self.tokenizer,
+                lora_names=lora_names,
                 create_time=self.create_time,
                 request_converter=self._determine_request_converter(backend),
             )
@@ -343,7 +384,10 @@ async def _streaming_chat_iterator(
         yield "data: [DONE]\n\n"

     def _validate_chat_request(
-        self, request: CreateChatCompletionRequest, metadata: TritonModelMetadata
+        self,
+        request: CreateChatCompletionRequest,
+        metadata: TritonModelMetadata,
+        lora_name: str | None,
     ):
         """
         Validates a chat request to align with currently supported features.
@@ -362,6 +406,13 @@ def _validate_chat_request(
         if not metadata.request_converter:
             raise Exception(f"Unknown request format for model: {request.model}")

+        if (
+            metadata.lora_names is not None
+            and lora_name is not None
+            and lora_name not in metadata.lora_names
+        ):
+            raise Exception(f"Unknown LoRA: {lora_name}; for model: {request.model}")
+
         # Reject unsupported features if requested
         if request.n and request.n > 1:
             raise Exception(
@@ -396,7 +447,10 @@ async def _streaming_completion_iterator(
         yield "data: [DONE]\n\n"

     def _validate_completion_request(
-        self, request: CreateCompletionRequest, metadata: TritonModelMetadata
+        self,
+        request: CreateCompletionRequest,
+        metadata: TritonModelMetadata,
+        lora_name: str | None,
     ):
         """
         Validates a completions request to align with currently supported features.
@@ -411,6 +465,13 @@ def _validate_completion_request(
         if not metadata.request_converter:
             raise Exception(f"Unknown request format for model: {request.model}")

+        if (
+            metadata.lora_names is not None
+            and lora_name is not None
+            and lora_name not in metadata.lora_names
+        ):
+            raise Exception(f"Unknown LoRA: {lora_name}; for model: {request.model}")
+
         # Reject unsupported features if requested
         if request.suffix is not None:
             raise Exception("suffix is not currently supported")
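
To make the routing change above easier to follow, here is a standalone sketch that mirrors the rule implemented by the new `_get_model_and_lora_name` helper: the incoming model id is split on the configured separator, and anything other than exactly one base name plus one LoRA name falls back to "no LoRA". The function name below is illustrative, not part of the module.

```python
# Illustrative restatement of the splitting rule added in
# TritonLLMEngine._get_model_and_lora_name (not the class method itself).
from typing import Optional, Tuple


def split_model_and_lora(
    request_model_name: str, lora_separator: Optional[str]
) -> Tuple[str, Optional[str]]:
    # No separator configured: the whole id is the model name.
    if lora_separator is None or len(lora_separator) == 0:
        return request_model_name, None

    # Exactly one separator occurrence selects a LoRA; anything else falls
    # back to treating the id as a plain model name.
    names = request_model_name.split(lora_separator)
    if len(names) != 2:
        return request_model_name, None

    return names[0], names[1]


print(split_model_and_lora("gemma-2b_lora_doll", "_lora_"))  # ('gemma-2b', 'doll')
print(split_model_and_lora("gemma-2b", "_lora_"))            # ('gemma-2b', None)
```

Note that the validation added in `_validate_chat_request` and `_validate_completion_request` only raises when the backend reports a `lora_names` list and the requested name is not in it; if no list is available, the LoRA name is passed through to the backend.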
