Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit f9f566b

Browse files
BERT & BART refactored
1 parent 1d11290 commit f9f566b

File tree

2 files changed

+115
-148
lines changed

2 files changed

+115
-148
lines changed

tftrt/examples/transformers/scripts/base_script.sh

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ DATA_DIR="/tmp"
1414

1515
BYPASS_ARGUMENTS=""
1616
TF_AUTO_JIT_XLA_FLAG=""
17+
BATCH_SIZE=32
18+
SEQ_LEN=128
1719

1820
# Loop through arguments and process them
1921
for arg in "$@"
@@ -27,14 +29,18 @@ do
2729
NVIDIA_TF32_OVERRIDE="NVIDIA_TF32_OVERRIDE=0"
2830
shift # Remove --no_tf32 from processing
2931
;;
32+
--batch_size=*)
33+
BATCH_SIZE="${arg#*=}"
34+
shift # Remove --batch_size= from processing
35+
;;
3036
--data_dir=*)
3137
shift # Remove --data_dir= from processing
3238
;;
33-
--vocab_size=*)
34-
shift # Remove --vocab_size= from processing
39+
--total_max_samples=*)
40+
shift # Remove --total_max_samples= from processing
3541
;;
36-
--minimum_segment_size=*)
37-
shift # Remove --minimum_segment_size= from processing
42+
--output_tensors_name=*)
43+
shift # Remove --output_tensors_name= from processing
3844
;;
3945
--input_saved_model_dir=*)
4046
MODEL_DIR="${arg#*=}"
@@ -44,6 +50,13 @@ do
4450
TF_AUTO_JIT_XLA_FLAG="TF_XLA_FLAGS=--tf_xla_auto_jit=2"
4551
shift # Remove --use_xla_auto_jit from processing
4652
;;
53+
--vocab_size=*)
54+
shift # Remove --vocab_size= from processing
55+
;;
56+
--sequence_length=*)
57+
SEQ_LEN="${arg#*=}"
58+
shift # Remove --sequence_length= from processing
59+
;;
4760
*)
4861
BYPASS_ARGUMENTS=" ${BYPASS_ARGUMENTS} ${arg}"
4962
;;
@@ -54,6 +67,9 @@ done
5467

5568
MIN_SEGMENT_SIZE=5
5669
VOCAB_SIZE=-1
70+
MAX_WORKSPACE_SIZE=$((2 ** (32 + 1))) # + 1 necessary compared to python
71+
MAX_SAMPLES=1
72+
OUTPUT_TENSORS_NAME="prediction_logits,seq_relationship_logits"
5773

5874
case ${MODEL_NAME} in
5975
"bert_base_uncased" | "bert_large_uncased")
@@ -67,6 +83,7 @@ case ${MODEL_NAME} in
6783
"bart_base" | "bart_large")
6884
VOCAB_SIZE=50265
6985
MIN_SEGMENT_SIZE=90
86+
OUTPUT_TENSORS_NAME="encoder_last_hidden_state,logits"
7087
;;
7188
esac
7289

@@ -80,9 +97,12 @@ echo "[*] MODEL_DIR: ${MODEL_DIR}"
8097
echo ""
8198
echo "[*] NVIDIA_TF32_OVERRIDE: ${NVIDIA_TF32_OVERRIDE}"
8299
echo ""
83-
# Custom Transormers Task Flags
84-
echo "[*] MIN_SEGMENT_SIZE: ${MIN_SEGMENT_SIZE}"
100+
# Custom Transformer Task Flags
85101
echo "[*] VOCAB_SIZE: ${VOCAB_SIZE}"
102+
echo "[*] SEQ_LEN: ${SEQ_LEN}"
103+
echo "[*] MAX_WORKSPACE_SIZE: ${MAX_WORKSPACE_SIZE}"
104+
echo "[*] MAX_SAMPLES: ${MAX_SAMPLES}"
105+
echo "[*] OUTPUT_TENSORS_NAME: ${OUTPUT_TENSORS_NAME}"
86106
echo ""
87107
echo "[*] TF_AUTO_JIT_XLA_FLAG: ${TF_AUTO_JIT_XLA_FLAG}"
88108
echo "[*] BYPASS_ARGUMENTS: $(echo \"${BYPASS_ARGUMENTS}\" | tr -s ' ')"
@@ -132,10 +152,16 @@ cd ${BENCH_DIR}
132152
PREPEND_COMMAND="${TF_AUTO_JIT_XLA_FLAG} ${NVIDIA_TF32_OVERRIDE}"
133153

134154
COMMAND="${PREPEND_COMMAND} python transformers.py \
135-
--input_saved_model_dir ${INPUT_SAVED_MODEL_DIR} \
136155
--data_dir ${DATA_DIR} \
156+
--calib_data_dir ${DATA_DIR} \
157+
--input_saved_model_dir ${INPUT_SAVED_MODEL_DIR} \
158+
--output_saved_model_dir /tmp/$RANDOM \
159+
--batch_size ${BATCH_SIZE} \
137160
--vocab_size ${VOCAB_SIZE} \
138-
--minimum_segment_size ${MIN_SEGMENT_SIZE} \
161+
--sequence_length=${SEQ_LEN} \
162+
--max_workspace_size ${MAX_WORKSPACE_SIZE} \
163+
--total_max_samples=${MAX_SAMPLES} \
164+
--output_tensors_name=${OUTPUT_TENSORS_NAME} \
139165
${BYPASS_ARGUMENTS}"
140166

141167
COMMAND=$(echo ${COMMAND} | sed 's/ *$//g') # Trimming whitespaces

tftrt/examples/transformers/transformers.py

Lines changed: 81 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,9 @@
1818
import os
1919
import sys
2020

21-
import logging
22-
import multiprocessing
23-
import time
24-
25-
from functools import partial
26-
2721
import numpy as np
28-
import tensorflow as tf
2922

30-
from statistics import mean
23+
import tensorflow as tf
3124

3225
# Allow import of top level python files
3326
import inspect
@@ -54,34 +47,33 @@ def __init__(self):
5447
super(CommandLineAPI, self).__init__()
5548

5649
self._parser.add_argument(
57-
'--sequence_length',
50+
"--sequence_length",
5851
type=int,
5952
default=128,
60-
help='Directory containing the input saved model.'
53+
help="Input data sequence length."
6154
)
6255

6356
self._parser.add_argument(
64-
'--vocab_size',
57+
"--vocab_size",
6558
type=int,
6659
required=True,
6760
choices=self.ALLOWED_VOCAB_SIZES,
68-
help='Size of the vocabulory used for '
69-
'training. Refer to huggingface '
70-
'documentation.'
61+
help="Size of the vocabulory used for training. Refer to "
62+
"huggingface documentation."
7163
)
7264

73-
self._parser.add_argument(
74-
'--validate_output',
75-
action='store_true',
76-
help='Validates that the model returns the correct '
77-
'value. This only works with batch_size =32.'
78-
)
65+
# self._parser.add_argument(
66+
# "--validate_output",
67+
# action="store_true",
68+
# help="Validates that the model returns the correct value. This "
69+
# "only works with batch_size =32."
70+
# )
7971

8072
def _validate_args(self, args):
8173
super(CommandLineAPI, self)._validate_args(args)
8274

83-
if args.validate_output and args.batch_size != 32:
84-
raise ValueError("Output validation only supports batch size 32.")
75+
# if args.validate_output and args.batch_size != 32:
76+
# raise ValueError("Output validation only supports batch size 32.")
8577

8678
# TODO: Remove when proper dataloading is implemented
8779
if args.num_iterations is None:
@@ -90,145 +82,94 @@ def _validate_args(self, args):
9082
"--num_iterations=None"
9183
)
9284

93-
# TODO: Remove when proper dataloading is implemented
9485
def _post_process_args(self, args):
86+
args = super(CommandLineAPI, self)._post_process_args(args)
87+
9588
return args
9689

9790

91+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
92+
# %%%%%%%%%%%%%%%%% IMPLEMENT MODEL-SPECIFIC FUNCTIONS HERE %%%%%%%%%%%%%%%%%% #
93+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
94+
95+
9896
class BenchmarkRunner(BaseBenchmarkRunner):
9997

100-
ACCURACY_METRIC_NAME = "mAP"
98+
def get_dataset_batches(self):
99+
"""Returns a list of batches of input samples.
101100
102-
def before_benchmark(self, **kwargs):
103-
pass
101+
Each batch should be in the form [x, y], where
102+
x is a numpy array of the input samples for the batch, and
103+
y is a numpy array of the expected model outputs for the batch
104104
105-
def compute_accuracy_metric(self, predictions, expected, **kwargs):
106-
pass
105+
Returns:
106+
- dataset: a TF Dataset object
107+
- bypass_data_to_eval: any object type that will be passed unmodified to
108+
`evaluate_result()`. If not necessary: `None`
107109
108-
def process_model_output(self, outputs, **kwargs):
109-
pass
110+
Note: script arguments can be accessed using `self._args.attr`
111+
"""
110112

113+
if not self._args.use_synthetic_data:
114+
raise NotImplementedError()
111115

112-
# def validate_model_artifacts(infer_func, model_dir, use_tftrt, precision):
113-
# numpy_asset_dir = os.path.join(model_dir, "numpy_assets")
114-
#
115-
# input_data = np.load(os.path.join(numpy_asset_dir, 'input_data.npy'))
116-
# input_data = tf.constant(input_data, dtype=tf.int32)
117-
#
118-
# output = infer_func(input_ids=input_data)
119-
#
120-
# if use_tftrt:
121-
# if precision == "fp16":
122-
# rtol=1e-2
123-
# atol=2e-1
124-
# else:
125-
# rtol=1e-2
126-
# atol=5e-2
127-
# else:
128-
# rtol=1e-5
129-
# atol=1e-8
130-
#
131-
# for key in output.keys():
132-
# target = np.load(os.path.join(numpy_asset_dir, '%s.npy' % key))
133-
# np.testing.assert_allclose(
134-
# target, output[key].numpy(), rtol=rtol, atol=atol
135-
# )
136-
# print("\n*****************************************************************")
137-
# print("Model was validated with success ...")
138-
# print("*****************************************************************\n")
116+
tf.random.set_seed(10)
117+
118+
input_data = tf.random.uniform(
119+
shape=(1, self._args.sequence_length),
120+
maxval=self._args.vocab_size,
121+
dtype=tf.int32
122+
)
123+
124+
dataset = tf.data.Dataset.from_tensor_slices(input_data)
125+
dataset = dataset.repeat()
126+
dataset = dataset.batch(self._args.batch_size)
127+
dataset = dataset.take(count=1) # loop over 1 batch
128+
dataset = dataset.cache()
129+
dataset = dataset.repeat()
130+
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
131+
132+
return dataset, None
133+
134+
def preprocess_model_inputs(self, data_batch):
135+
"""This function prepares the `data_batch` generated from the dataset.
136+
Returns:
137+
x: input of the model
138+
y: data to be used for model evaluation
139139
140+
Note: script arguments can be accessed using `self._args.attr`
141+
"""
140142

141-
def get_dataset(batch_size, seq_len, vocab_size, use_synthetic_data):
143+
x = data_batch
144+
return x, None
142145

143-
if not use_synthetic_data:
144-
raise NotImplementedError()
146+
def postprocess_model_outputs(self, predictions, expected):
147+
"""Post process if needed the predictions and expected tensors. At the
148+
minimum, this function transforms all TF Tensors into numpy arrays.
149+
Most models will not need to modify this function.
145150
146-
tf.random.set_seed(10)
147-
input_data = tf.random.uniform(
148-
shape=(1, seq_len), maxval=vocab_size, dtype=tf.int32
149-
)
151+
Note: script arguments can be accessed using `self._args.attr`
152+
"""
150153

151-
dataset = tf.data.Dataset.from_tensor_slices(input_data)
152-
dataset = dataset.repeat()
153-
dataset = dataset.batch(batch_size)
154-
dataset = dataset.take(count=1) # loop over 1 batch
155-
dataset = dataset.cache()
156-
dataset = dataset.repeat()
157-
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
154+
return predictions.numpy(), expected.numpy()
158155

159-
return dataset
156+
def evaluate_model(self, predictions, expected, bypass_data_to_eval):
157+
"""Evaluate result predictions for entire dataset.
158+
159+
This computes overall accuracy, mAP, etc. Returns the
160+
metric value and a metric_units string naming the metric.
161+
162+
Note: script arguments can be accessed using `self._args.attr`
163+
"""
164+
165+
return None, "Top-1 Accuracy %"
160166

161167

162168
if __name__ == '__main__':
169+
163170
cmdline_api = CommandLineAPI()
164171
args = cmdline_api.parse_args()
165172

166-
def _input_fn(build_steps, model_phase):
167-
168-
dataset = get_dataset(
169-
batch_size=args.batch_size,
170-
seq_len=args.sequence_length,
171-
vocab_size=args.vocab_size,
172-
use_synthetic_data=args.use_synthetic_data
173-
)
173+
runner = BenchmarkRunner(args)
174174

175-
for i, (input_batch) in enumerate(dataset):
176-
if i >= build_steps:
177-
break
178-
179-
print("* [%s] - step %04d/%04d" % (model_phase, i + 1, build_steps))
180-
yield input_batch,
181-
182-
calibration_input_fn = partial(
183-
_input_fn,
184-
build_steps=args.num_calib_batches // args.batch_size,
185-
model_phase="Calibration"
186-
)
187-
optimize_offline_input_fn = partial(
188-
_input_fn, build_steps=1, model_phase="Building"
189-
)
190-
191-
runner = BenchmarkRunner(
192-
input_saved_model_dir=args.input_saved_model_dir,
193-
output_saved_model_dir=args.output_saved_model_dir,
194-
allow_build_at_runtime=args.allow_build_at_runtime,
195-
calibration_input_fn=calibration_input_fn,
196-
debug=args.debug,
197-
gpu_mem_cap=args.gpu_mem_cap,
198-
input_signature_key=args.input_signature_key,
199-
max_workspace_size_bytes=args.max_workspace_size,
200-
minimum_segment_size=args.minimum_segment_size,
201-
num_calib_batches=args.num_calib_batches,
202-
optimize_offline=args.optimize_offline,
203-
optimize_offline_input_fn=optimize_offline_input_fn,
204-
output_tensor_names=args.output_tensor_names,
205-
precision_mode=args.precision,
206-
use_dynamic_shape=args.use_dynamic_shape,
207-
use_tftrt=args.use_tftrt
208-
)
209-
210-
# if args.validate_output:
211-
# # artifacts only generated for BS == 32
212-
# validate_model_artifacts(
213-
# graph_func,
214-
# args.input_saved_model_dir,
215-
# args.use_tftrt,
216-
# args.precision.lower()
217-
# )
218-
219-
get_benchmark_input_fn = partial(
220-
get_dataset, seq_len=args.sequence_length, vocab_size=args.vocab_size
221-
)
222-
223-
runner.execute_benchmark(
224-
batch_size=args.batch_size,
225-
display_every=args.display_every,
226-
get_benchmark_input_fn=get_benchmark_input_fn,
227-
num_iterations=args.num_iterations,
228-
num_warmup_iterations=args.num_warmup_iterations,
229-
skip_accuracy_testing=(
230-
args.use_synthetic_data or args.skip_accuracy_testing
231-
),
232-
use_synthetic_data=args.use_synthetic_data,
233-
use_xla=args.use_xla,
234-
)
175+
runner.execute_benchmark()

0 commit comments

Comments
 (0)