Skip to content

Commit bf88f4a

Browse files
committed
Fix L0_orca_trtllm which was broken due to new changes in the trtllm directory structure and add func defn
1 parent ff4bd4e commit bf88f4a

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

qa/L0_orca/test.sh

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ MODEL_NAME="gpt2_tensorrt_llm"
3636
NAME="tensorrt_llm_benchmarking_test"
3737
MODEL_REPOSITORY="$(pwd)/triton_model_repo"
3838
TENSORRTLLM_BACKEND_DIR="/workspace/tensorrtllm_backend"
39-
GPT_DIR="$TENSORRTLLM_BACKEND_DIR/tensorrt_llm/examples/gpt"
39+
GPT_DIR="$TENSORRTLLM_BACKEND_DIR/tensorrt_llm/examples/models/core/gpt"
4040
TOKENIZER_DIR="$GPT_DIR/gpt2"
4141
ENGINES_DIR="${BASE_DIR}/engines/inflight_batcher_llm/${NUM_GPUS}-gpu"
4242
TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
@@ -48,6 +48,13 @@ CLIENT_PY=${BASE_DIR}/orca_http_test.py
4848
CLIENT_LOG="${NAME}_orca_http_test.log"
4949
source ../common/util.sh
5050

51+
# Replace every occurrence of a placeholder tag in a config file, in place.
#
# Arguments:
#   $1 - tag to replace. It is used as the pattern of a sed `s` command, so
#        it is interpreted as a basic regular expression.
#   $2 - replacement text. `&` and backslashes keep their sed meaning.
#   $3 - path of the config file to edit (modified in place via `sed -i`).
#
# `|` is the sed delimiter, so neither the tag nor the value may contain it.
function replace_config_tags {
    local tag_to_replace="${1}"
    local new_value="${2}"
    local config_file_path="${3}"
    # The path is quoted so a file name containing whitespace is passed to
    # sed as one operand rather than being split into several.
    sed -i "s|${tag_to_replace}|${new_value}|g" "${config_file_path}"
}
57+
5158
function prepare_model_repository {
5259
rm -rf ${MODEL_REPOSITORY} && mkdir ${MODEL_REPOSITORY}
5360
cp -r ${TENSORRTLLM_BACKEND_DIR}/all_models/inflight_batcher_llm/* ${MODEL_REPOSITORY}
@@ -138,6 +145,42 @@ function kill_server {
138145
done
139146
}
140147

148+
# Create a fresh shallow checkout of the tensorrtllm_backend repository.
#
# Globals read:
#   TENSORRTLLM_BACKEND_DIR      - checkout directory; removed and recreated.
#   TENSORRTLLM_BACKEND_REPO_TAG - branch or tag passed to `git clone -b`.
#   TRITON_REPO_ORG              - URL prefix the repository name is appended to.
#
# Side effects: installs git-lfs through apt-get and leaves the working
# directory inside the checkout (the function does not cd back).
function clone_tensorrt_llm_backend_repo {
    # `:?` aborts instead of handing an empty path to `rm -rf`.
    rm -rf "${TENSORRTLLM_BACKEND_DIR:?}" && mkdir "${TENSORRTLLM_BACKEND_DIR}"
    apt-get update && apt-get install git-lfs -y --no-install-recommends
    git clone --single-branch --depth=1 -b "${TENSORRTLLM_BACKEND_REPO_TAG}" "${TRITON_REPO_ORG}/tensorrtllm_backend.git" "${TENSORRTLLM_BACKEND_DIR}"
    # Submodules are initialised recursively from inside the checkout, after
    # the git-lfs hooks have been installed.
    cd "${TENSORRTLLM_BACKEND_DIR}" && git lfs install && git submodule update --init --recursive
}
154+
155+
# Fetch the gpt2-medium model and convert its weights to a checkpoint.
#
# Globals read:
#   GPT_DIR  - directory holding convert_checkpoint.py; the model is cloned
#              into "${GPT_DIR}/gpt2".
#   NUM_GPUS - passed as --tp_size and used in the output directory name.
#   BASE_DIR - working directory restored before returning.
#
# Exits the script with status 1 if the clone or the weight download fails.
function build_gpt2_base_model {
    # Download weights from HuggingFace Transformers.
    # Stop on failure: otherwise the rm and wget below would run in whatever
    # directory the shell happened to be left in.
    if ! { cd "${GPT_DIR}" && rm -rf gpt2 && git clone https://huggingface.co/gpt2-medium gpt2 && cd gpt2; }; then
        echo "Cloning gpt2-medium into ${GPT_DIR}/gpt2 failed."
        exit 1
    fi
    # Drop the weight files that came with the clone; pytorch_model.bin is
    # fetched directly below.
    # NOTE(review): presumably the clone only provides git-lfs pointers here
    # and model.safetensors must not be picked up instead - confirm.
    rm -f pytorch_model.bin model.safetensors
    if ! wget -q https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin; then
        echo "Downloading pytorch_model.bin failed."
        exit 1
    fi
    cd "${GPT_DIR}"

    # Convert weights from HF Transformers to FT format
    python3 convert_checkpoint.py --model_dir gpt2 --dtype float16 --tp_size "${NUM_GPUS}" --output_dir "./c-model/gpt2/${NUM_GPUS}-gpu/"
    cd "${BASE_DIR}"
}
169+
170+
# Build the TensorRT engines from the converted GPT-2 checkpoint.
#
# Globals read:
#   GPT_DIR     - directory that contains ./c-model/gpt2/<NUM_GPUS>-gpu/.
#   NUM_GPUS    - selects the checkpoint directory and is passed as --workers.
#   ENGINES_DIR - directory handed to trtllm-build as --output_dir.
#   BASE_DIR    - working directory restored before returning.
#
# Exits the script with status 1 if GPT_DIR cannot be entered.
function build_gpt2_tensorrt_engine {
    # Build TensorRT engines
    # --checkpoint_dir is relative, so the build must start from GPT_DIR.
    if ! cd "${GPT_DIR}"; then
        echo "Changing directory to ${GPT_DIR} failed."
        exit 1
    fi
    trtllm-build --checkpoint_dir "./c-model/gpt2/${NUM_GPUS}-gpu/" \
        --gpt_attention_plugin float16 \
        --remove_input_padding enable \
        --paged_kv_cache enable \
        --gemm_plugin float16 \
        --workers "${NUM_GPUS}" \
        --output_dir "${ENGINES_DIR}"

    cd "${BASE_DIR}"
}
183+
141184
clone_tensorrt_llm_backend_repo
142185
build_gpt2_base_model
143186
build_gpt2_tensorrt_engine

qa/L0_perf_tensorrt_llm/test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -35,7 +35,7 @@ MODEL_NAME="gpt2_tensorrt_llm"
3535
NAME="tensorrt_llm_benchmarking_test"
3636
MODEL_REPOSITORY="$(pwd)/triton_model_repo"
3737
TENSORRTLLM_BACKEND_DIR="/workspace/tensorrtllm_backend"
38-
GPT_DIR="$TENSORRTLLM_BACKEND_DIR/tensorrt_llm/examples/gpt"
38+
GPT_DIR="$TENSORRTLLM_BACKEND_DIR/tensorrt_llm/examples/models/core/gpt"
3939
TOKENIZER_DIR="$GPT_DIR/gpt2"
4040
ENGINES_DIR="${BASE_DIR}/engines/inflight_batcher_llm/${NUM_GPUS}-gpu"
4141
TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}

0 commit comments

Comments
 (0)