Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 48b6c03

Browse files
pavanimajety authored and DEKHTIARJonathan committed
[Benchmark-Py] Adding TF-HUB ALBERT
1 parent 05425d2 commit 48b6c03

File tree

9 files changed

+1490
-0
lines changed

9 files changed

+1490
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
#!/usr/bin/env bash
# Benchmark launcher for the TF-HUB ALBERT model (TF-TRT benchmarking-python
# suite). Installs runtime deps, parses CLI overrides, validates the data /
# model / tokenizer directories, then launches infer.py with synthetic data.

nvidia-smi

set -x

BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

# Install packages required by the rest of the script.
pip install tensorflow_text tensorflow_hub scipy==1.4.1

# Runtime Parameters
MODEL_NAME=""
DATASET_NAME=""

# Default Argument Values
BATCH_SIZE=32
SEQ_LEN=128
VOCAB_SIZE=33000

NUM_ITERATIONS=1000
OUTPUT_TENSOR_NAMES="albert_encoder,albert_encoder_1,albert_encoder_2,albert_encoder_3,albert_encoder_4,albert_encoder_5,albert_encoder_6,albert_encoder_7,albert_encoder_8,albert_encoder_9,albert_encoder_10,albert_encoder_11,albert_encoder_12,albert_encoder_13"

BYPASS_ARGUMENTS=""

# Loop through arguments and process them.
# NOTE: recognized flags are `shift`ed off the positional parameters so that
# the trailing "$@" forwarded to infer.py only carries unprocessed arguments.
for arg in "$@"
do
    case $arg in
        --model_name=*)
        MODEL_NAME="${arg#*=}"
        shift # Remove --model_name= from processing
        ;;
        --dataset_name=*)
        DATASET_NAME="${arg#*=}"
        shift # Remove --dataset_name= from processing
        ;;
        --batch_size=*)
        BATCH_SIZE="${arg#*=}"
        shift # Remove --batch_size= from processing
        ;;
        --sequence_length=*)
        SEQ_LEN="${arg#*=}"
        shift # Remove --sequence_length= from processing
        ;;
        --num_iterations=*)
        NUM_ITERATIONS="${arg#*=}"
        shift # Remove --num_iterations= from processing
        ;;
        --vocab_size=*)
        VOCAB_SIZE="${arg#*=}"
        shift # Remove --vocab_size= from processing
        ;;
        --output_tensors_name=*)
        OUTPUT_TENSOR_NAMES="${arg#*=}"
        shift # Remove --output_tensors_name= from processing
        ;;
        ######### IGNORE ARGUMENTS BELOW
        --data_dir=*)
        shift # Remove --data_dir= from processing
        ;;
        --input_saved_model_dir=*)
        shift # Remove --input_saved_model_dir= from processing
        ;;
        --tokenizer_dir=*)
        shift # Remove --tokenizer_dir= from processing
        ;;
        --total_max_samples=*)
        shift # Remove --total_max_samples= from processing
        ;;
        *)
        BYPASS_ARGUMENTS=" ${BYPASS_ARGUMENTS} ${arg}"
        ;;
    esac
done

# Hard-coded directories for this benchmark; defined BEFORE the summary echo
# block so DATA_DIR is no longer printed as an empty string.
DATA_DIR="/workspace/tftrt/benchmarking-python/tf_hub/albert/data"
MODEL_DIR="/models/tf_hub/albert/${MODEL_NAME}/"
TOKENIZER_DIR="/models/tf_hub/albert/tokenizer"

echo -e "\n********************************************************************"
echo "[*] MODEL_NAME: ${MODEL_NAME}"
echo "[*] DATASET_NAME: ${DATASET_NAME}"
echo ""
echo "[*] DATA_DIR: ${DATA_DIR}"
echo "[*] BATCH_SIZE: ${BATCH_SIZE}"
echo ""
# Custom Task Flags
echo "[*] SEQ_LEN: ${SEQ_LEN}"
echo "[*] OUTPUT_TENSOR_NAMES: ${OUTPUT_TENSOR_NAMES}"
echo ""
# tr -s squeezes the extra spaces accumulated while appending bypass args.
echo "[*] BYPASS_ARGUMENTS: $(echo "${BYPASS_ARGUMENTS}" | tr -s ' ')"

echo -e "********************************************************************\n"

if [[ ! -d ${DATA_DIR} ]]; then
    echo "ERROR: \`--data_dir=/path/to/directory\` does not exist. [Received: \`${DATA_DIR}\`]"
    exit 1
fi

if [[ ! -d ${MODEL_DIR} ]]; then
    echo "ERROR: \`--input_saved_model_dir=/path/to/directory\` does not exist. [Received: \`${MODEL_DIR}\`]"
    exit 1
fi

if [[ ! -d ${TOKENIZER_DIR} ]]; then
    echo "ERROR: \`--tokenizer_dir=/path/to/directory\` does not exist. [Received: \`${TOKENIZER_DIR}\`]"
    exit 1
fi

# Dataset Directory

python "${BASE_DIR}/infer.py" \
    --data_dir="${DATA_DIR}" \
    --calib_data_dir="${DATA_DIR}" \
    --input_saved_model_dir="${MODEL_DIR}" \
    --tokenizer_dir="${TOKENIZER_DIR}" \
    --output_tensors_name="${OUTPUT_TENSOR_NAMES}" \
    `# The following is set because we will be running synthetic benchmarks` \
    --total_max_samples=1 \
    --use_synthetic_data \
    --num_iterations="${NUM_ITERATIONS}" \
    "$@"

0 commit comments

Comments
 (0)