
Commit 43107e3

Improve TensorRT-LLM Functionality (#487)
- Changes to get TensorRT-LLM to work with Mixtral.
- Update the bundled TensorRT-LLM code and build process to a newer version.
- Add some bits to mitigate tokenization issues.

Note: the logprobs returned still aren't correct; this hasn't been investigated yet. Stop sequences also don't completely work; to my knowledge this is partly a limitation of how TensorRT/Triton works, but there may be another way around it.
1 parent 6e2ebff commit 43107e3

15 files changed (+1443, -113 lines)


model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 23 additions & 4 deletions
@@ -216,7 +216,9 @@
             "llama-2-70b-chat",
         ]
     ),
-    LLMInferenceFramework.TENSORRT_LLM: set(["llama-2-7b"]),
+    LLMInferenceFramework.TENSORRT_LLM: set(
+        ["llama-2-7b", "mixtral-8x7b", "mixtral-8x7b-instruct"]
+    ),
 }
 
 _SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = {
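
(For context on the hunk above: this framework-to-model allow-list gates which model names can be created on each inference framework. Below is a minimal sketch of that gate; the dict name, enum values, and the validate_model_name helper are simplified stand-ins for illustration, not the repo's actual code.)

```python
# Minimal sketch of an allow-list gate like the one edited above.
# The names below are simplified stand-ins, not the repo's real identifiers.
from enum import Enum
from typing import Dict, Set


class LLMInferenceFramework(str, Enum):
    TENSORRT_LLM = "tensorrt_llm"


SUPPORTED_MODELS: Dict[LLMInferenceFramework, Set[str]] = {
    LLMInferenceFramework.TENSORRT_LLM: {
        "llama-2-7b",
        "mixtral-8x7b",
        "mixtral-8x7b-instruct",
    },
}


def validate_model_name(framework: LLMInferenceFramework, model_name: str) -> None:
    """Raise if the requested model is not in the framework's allow-list."""
    supported = SUPPORTED_MODELS.get(framework, set())
    if model_name not in supported:
        raise ValueError(f"{model_name} is not supported on {framework.value}")


validate_model_name(LLMInferenceFramework.TENSORRT_LLM, "mixtral-8x7b")  # now passes
```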
@@ -1467,11 +1469,28 @@ def model_output_to_completion_output(
             num_prompt_tokens = count_tokens(
                 prompt, model_content.model_name, self.tokenizer_repository
             )
-            return CompletionOutput(
+            if "token_ids" in model_output:
+                # TensorRT 23.10 has this field, TensorRT 24.03 does not
+                # For backwards compatibility with pre-2024/05/02
+                num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens
                 # Output is "<s> prompt output"
-                text=model_output["text_output"][(len(prompt) + 4) :],
+                text = model_output["text_output"][(len(prompt) + 4) :]
+            elif "output_log_probs" in model_output:
+                # TensorRT 24.01 + surrounding code.
+                # For some reason TRT returns output_log_probs as either a list or a float
+                # Also the log probs don't look right, so returning log-probs is still broken
+                num_completion_tokens = (
+                    len(model_output["output_log_probs"])
+                    if type(model_output["output_log_probs"]) == list
+                    else 1
+                )
+                # Output is just "output". See `exclude_input_in_output` inside of
+                # inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
+                text = model_output["text_output"]
+            return CompletionOutput(
+                text=text,
                 num_prompt_tokens=num_prompt_tokens,
-                num_completion_tokens=len(model_output["token_ids"]) - num_prompt_tokens,
+                num_completion_tokens=num_completion_tokens,
             )
         else:
             raise EndpointUnsupportedInferenceTypeException(
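
(The branch added above handles two different Triton response shapes. As a hedged, standalone illustration of that logic, here is the same computation pulled into a plain function with made-up payloads; the dict contents are illustrative examples, not captured Triton responses.)

```python
# Standalone sketch of the token-count/text extraction above.
# The example payloads are illustrative, not real Triton responses.
from typing import Tuple


def extract_completion(model_output: dict, prompt: str, num_prompt_tokens: int) -> Tuple[str, int]:
    if "token_ids" in model_output:
        # Older (23.10-style) response: token_ids covers prompt + completion,
        # and text_output is "<s> prompt output".
        num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens
        text = model_output["text_output"][(len(prompt) + 4) :]  # drop "<s> " + prompt
    elif "output_log_probs" in model_output:
        # Newer (24.03-style) response with exclude_input_in_output: text_output is
        # only the completion, and output_log_probs may be a list or a single float.
        log_probs = model_output["output_log_probs"]
        num_completion_tokens = len(log_probs) if isinstance(log_probs, list) else 1
        text = model_output["text_output"]
    else:
        raise ValueError("unrecognized TensorRT-LLM response shape")
    return text, num_completion_tokens


old_style = {"token_ids": [0] * 12, "text_output": "<s> Hello world! How are you?"}
new_style = {"output_log_probs": [-0.3, -1.2, -0.8], "text_output": " How are you?"}
print(extract_completion(old_style, "Hello world!", num_prompt_tokens=5))  # (' How are you?', 7)
print(extract_completion(new_style, "Hello world!", num_prompt_tokens=5))  # (' How are you?', 3)
```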

model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3
+FROM nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3
 
 COPY requirements.txt /workspace/requirements.txt
 WORKDIR /workspace
Lines changed: 14 additions & 0 deletions (new README alongside the Dockerfile above)
@@ -0,0 +1,14 @@
+# Preparing the model weights/tokenizers
+
+Our TensorRT-LLM docker image expects weights to live in s3/other blob store with the following directory structure:
+
+root/
+    model_tokenizer/
+        <everything in a HF directory other than the weights themselves>
+    model_weights/
+        config.json
+        rank<i>.engine
+
+You can obtain `model_weights` by building a TRT-LLM engine via the directions found on Nvidia's site (e.g. https://github.com/NVIDIA/TensorRT-LLM/blob/main/README.md#installation, https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/llama/convert_checkpoint.py)
+
+The inference image is built via the Dockerfile in the same directory as this readme.
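
(To make the expected layout concrete, here is a hypothetical helper, not part of the repo, that sanity-checks a local directory against the structure above before uploading it; the example path and world_size are assumptions.)

```python
# Hypothetical pre-upload check for the layout described above (not repo code).
from pathlib import Path


def check_trtllm_layout(root: str, world_size: int = 1) -> None:
    root_path = Path(root)
    tokenizer_dir = root_path / "model_tokenizer"
    weights_dir = root_path / "model_weights"
    assert tokenizer_dir.is_dir(), "missing model_tokenizer/ (HF files minus the weights)"
    assert (weights_dir / "config.json").is_file(), "missing model_weights/config.json"
    for i in range(world_size):
        engine = weights_dir / f"rank{i}.engine"
        assert engine.is_file(), f"missing {engine} (expect one engine per rank)"


# Example (assumed local path): an engine built with tensor parallelism 2 should
# contain rank0.engine and rank1.engine.
# check_trtllm_layout("/data/mixtral-8x7b-trtllm", world_size=2)
```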

model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py

Lines changed: 11 additions & 3 deletions
@@ -9,6 +9,12 @@ def parse_arguments():
         "--world_size", type=int, default=1, help="world size, only support tensor parallelism now"
     )
     parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver")
+    parser.add_argument(
+        "--http-address",
+        type=str,
+        default="ipv6:[::1]",
+        help="Default HTTP address to ipv6:[::1].",
+    )
     parser.add_argument(
         "--http-port",
         type=int,
@@ -20,14 +26,16 @@ def parse_arguments():
     return parser.parse_args()
 
 
-def get_cmd(world_size, tritonserver, model_repo, http_port):
+def get_cmd(world_size, tritonserver, model_repo, http_address, http_port):
     cmd = "mpirun --allow-run-as-root "
     for i in range(world_size):
-        cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address ipv6:[::1] --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : "
+        cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address {http_address} --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : "
     return cmd
 
 
 if __name__ == "__main__":
     args = parse_arguments()
-    cmd = get_cmd(int(args.world_size), args.tritonserver, args.model_repo, args.http_port)
+    cmd = get_cmd(
+        int(args.world_size), args.tritonserver, args.model_repo, args.http_address, args.http_port
+    )
     subprocess.call(cmd, shell=True)
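
(As a quick illustration of what the launcher now produces, here is a standalone copy of get_cmd from the diff above, printed for a 2-GPU tensor-parallel setup; the model_repo path and port are assumed values for the example.)

```python
# Standalone copy of get_cmd from the diff above, for illustration only.
def get_cmd(world_size, tritonserver, model_repo, http_address, http_port):
    cmd = "mpirun --allow-run-as-root "
    for i in range(world_size):
        cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address {http_address} --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : "
    return cmd


# Assumed repository path and port, just to show the shape of the command:
print(
    get_cmd(
        world_size=2,
        tritonserver="/opt/tritonserver/bin/tritonserver",
        model_repo="/workspace/triton_model_repo",
        http_address="ipv6:[::1]",
        http_port=8000,
    )
)
# One mpirun invocation launches one tritonserver process per rank, each with a
# distinct shared-memory region prefix (prefix0_, prefix1_).
```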
Lines changed: 2 additions & 1 deletion (the requirements.txt copied into the image by the Dockerfile above)
@@ -1,2 +1,3 @@
 sentencepiece==0.1.99
-protobuf==4.24.4
+protobuf==4.24.4
+torch==2.2.2

0 commit comments

Comments
 (0)