1818import os
1919import sys
2020
21- import logging
22- import multiprocessing
23- import time
24-
25- from functools import partial
26-
2721import numpy as np
28- import tensorflow as tf
2922
30- from statistics import mean
23+ import tensorflow as tf
3124
3225# Allow import of top level python files
3326import inspect
@@ -54,34 +47,33 @@ def __init__(self):
5447 super (CommandLineAPI , self ).__init__ ()
5548
5649 self ._parser .add_argument (
57- ' --sequence_length' ,
50+ " --sequence_length" ,
5851 type = int ,
5952 default = 128 ,
60- help = 'Directory containing the input saved model.'
53+ help = "Input data sequence length."
6154 )
6255
6356 self ._parser .add_argument (
64- ' --vocab_size' ,
57+ " --vocab_size" ,
6558 type = int ,
6659 required = True ,
6760 choices = self .ALLOWED_VOCAB_SIZES ,
68- help = 'Size of the vocabulory used for '
69- 'training. Refer to huggingface '
70- 'documentation.'
61+ help = "Size of the vocabulory used for training. Refer to "
62+ "huggingface documentation."
7163 )
7264
73- self ._parser .add_argument (
74- ' --validate_output' ,
75- action = ' store_true' ,
76- help = ' Validates that the model returns the correct '
77- 'value. This only works with batch_size =32.'
78- )
65+ # self._parser.add_argument(
66+ # " --validate_output" ,
67+ # action=" store_true" ,
68+ # help=" Validates that the model returns the correct value. This "
69+ # " only works with batch_size =32."
70+ # )
7971
8072 def _validate_args (self , args ):
8173 super (CommandLineAPI , self )._validate_args (args )
8274
83- if args .validate_output and args .batch_size != 32 :
84- raise ValueError ("Output validation only supports batch size 32." )
75+ # if args.validate_output and args.batch_size != 32:
76+ # raise ValueError("Output validation only supports batch size 32.")
8577
8678 # TODO: Remove when proper dataloading is implemented
8779 if args .num_iterations is None :
@@ -90,145 +82,94 @@ def _validate_args(self, args):
9082 "--num_iterations=None"
9183 )
9284
93- # TODO: Remove when proper dataloading is implemented
9485 def _post_process_args (self , args ):
86+ args = super (CommandLineAPI , self )._post_process_args (args )
87+
9588 return args
9689
9790
91+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
92+ # %%%%%%%%%%%%%%%%% IMPLEMENT MODEL-SPECIFIC FUNCTIONS HERE %%%%%%%%%%%%%%%%%% #
93+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
94+
95+
9896class BenchmarkRunner (BaseBenchmarkRunner ):
9997
100- ACCURACY_METRIC_NAME = "mAP"
98+ def get_dataset_batches (self ):
99+ """Returns a TF Dataset that yields batches of input samples.
101100
102- def before_benchmark (self , ** kwargs ):
103- pass
101+ Each batch should be in the form [x, y], where
102+ x is a numpy array of the input samples for the batch, and
103+ y is a numpy array of the expected model outputs for the batch
104104
105- def compute_accuracy_metric (self , predictions , expected , ** kwargs ):
106- pass
105+ Returns:
106+ - dataset: a TF Dataset object
107+ - bypass_data_to_eval: any object type that will be passed unmodified to
108+ `evaluate_model()`. If not necessary: `None`
107109
108- def process_model_output ( self , outputs , ** kwargs ):
109- pass
110+ Note: script arguments can be accessed using `self._args.attr`
111+ """
110112
113+ if not self ._args .use_synthetic_data :
114+ raise NotImplementedError ()
111115
112- # def validate_model_artifacts(infer_func, model_dir, use_tftrt, precision):
113- # numpy_asset_dir = os.path.join(model_dir, "numpy_assets")
114- #
115- # input_data = np.load(os.path.join(numpy_asset_dir, 'input_data.npy'))
116- # input_data = tf.constant(input_data, dtype=tf.int32)
117- #
118- # output = infer_func(input_ids=input_data)
119- #
120- # if use_tftrt:
121- # if precision == "fp16":
122- # rtol=1e-2
123- # atol=2e-1
124- # else:
125- # rtol=1e-2
126- # atol=5e-2
127- # else:
128- # rtol=1e-5
129- # atol=1e-8
130- #
131- # for key in output.keys():
132- # target = np.load(os.path.join(numpy_asset_dir, '%s.npy' % key))
133- # np.testing.assert_allclose(
134- # target, output[key].numpy(), rtol=rtol, atol=atol
135- # )
136- # print("\n*****************************************************************")
137- # print("Model was validated with success ...")
138- # print("*****************************************************************\n")
116+ tf .random .set_seed (10 )
117+
118+ input_data = tf .random .uniform (
119+ shape = (1 , self ._args .sequence_length ),
120+ maxval = self ._args .vocab_size ,
121+ dtype = tf .int32
122+ )
123+
124+ dataset = tf .data .Dataset .from_tensor_slices (input_data )
125+ dataset = dataset .repeat ()
126+ dataset = dataset .batch (self ._args .batch_size )
127+ dataset = dataset .take (count = 1 ) # loop over 1 batch
128+ dataset = dataset .cache ()
129+ dataset = dataset .repeat ()
130+ dataset = dataset .prefetch (tf .data .experimental .AUTOTUNE )
131+
132+ return dataset , None
133+
134+ def preprocess_model_inputs (self , data_batch ):
135+ """This function prepares the `data_batch` generated from the dataset.
136+ Returns:
137+ x: input of the model
138+ y: data to be used for model evaluation
139139
140+ Note: script arguments can be accessed using `self._args.attr`
141+ """
140142
141- def get_dataset (batch_size , seq_len , vocab_size , use_synthetic_data ):
143+ x = data_batch
144+ return x , None
142145
143- if not use_synthetic_data :
144- raise NotImplementedError ()
146+ def postprocess_model_outputs (self , predictions , expected ):
147+ """Post-process the predictions and expected tensors if needed. At a
148+ minimum, this function transforms all TF Tensors into numpy arrays.
149+ Most models will not need to modify this function.
145150
146- tf .random .set_seed (10 )
147- input_data = tf .random .uniform (
148- shape = (1 , seq_len ), maxval = vocab_size , dtype = tf .int32
149- )
151+ Note: script arguments can be accessed using `self._args.attr`
152+ """
150153
151- dataset = tf .data .Dataset .from_tensor_slices (input_data )
152- dataset = dataset .repeat ()
153- dataset = dataset .batch (batch_size )
154- dataset = dataset .take (count = 1 ) # loop over 1 batch
155- dataset = dataset .cache ()
156- dataset = dataset .repeat ()
157- dataset = dataset .prefetch (tf .data .experimental .AUTOTUNE )
154+ return predictions .numpy (), expected .numpy ()
158155
159- return dataset
156+ def evaluate_model (self , predictions , expected , bypass_data_to_eval ):
157+ """Evaluate result predictions for entire dataset.
158+
159+ This computes overall accuracy, mAP, etc. Returns the
160+ metric value and a metric_units string naming the metric.
161+
162+ Note: script arguments can be accessed using `self._args.attr`
163+ """
164+
165+ return None , "Top-1 Accuracy %"
160166
161167
162168if __name__ == '__main__' :
169+
163170 cmdline_api = CommandLineAPI ()
164171 args = cmdline_api .parse_args ()
165172
166- def _input_fn (build_steps , model_phase ):
167-
168- dataset = get_dataset (
169- batch_size = args .batch_size ,
170- seq_len = args .sequence_length ,
171- vocab_size = args .vocab_size ,
172- use_synthetic_data = args .use_synthetic_data
173- )
173+ runner = BenchmarkRunner (args )
174174
175- for i , (input_batch ) in enumerate (dataset ):
176- if i >= build_steps :
177- break
178-
179- print ("* [%s] - step %04d/%04d" % (model_phase , i + 1 , build_steps ))
180- yield input_batch ,
181-
182- calibration_input_fn = partial (
183- _input_fn ,
184- build_steps = args .num_calib_batches // args .batch_size ,
185- model_phase = "Calibration"
186- )
187- optimize_offline_input_fn = partial (
188- _input_fn , build_steps = 1 , model_phase = "Building"
189- )
190-
191- runner = BenchmarkRunner (
192- input_saved_model_dir = args .input_saved_model_dir ,
193- output_saved_model_dir = args .output_saved_model_dir ,
194- allow_build_at_runtime = args .allow_build_at_runtime ,
195- calibration_input_fn = calibration_input_fn ,
196- debug = args .debug ,
197- gpu_mem_cap = args .gpu_mem_cap ,
198- input_signature_key = args .input_signature_key ,
199- max_workspace_size_bytes = args .max_workspace_size ,
200- minimum_segment_size = args .minimum_segment_size ,
201- num_calib_batches = args .num_calib_batches ,
202- optimize_offline = args .optimize_offline ,
203- optimize_offline_input_fn = optimize_offline_input_fn ,
204- output_tensor_names = args .output_tensor_names ,
205- precision_mode = args .precision ,
206- use_dynamic_shape = args .use_dynamic_shape ,
207- use_tftrt = args .use_tftrt
208- )
209-
210- # if args.validate_output:
211- # # artifacts only generated for BS == 32
212- # validate_model_artifacts(
213- # graph_func,
214- # args.input_saved_model_dir,
215- # args.use_tftrt,
216- # args.precision.lower()
217- # )
218-
219- get_benchmark_input_fn = partial (
220- get_dataset , seq_len = args .sequence_length , vocab_size = args .vocab_size
221- )
222-
223- runner .execute_benchmark (
224- batch_size = args .batch_size ,
225- display_every = args .display_every ,
226- get_benchmark_input_fn = get_benchmark_input_fn ,
227- num_iterations = args .num_iterations ,
228- num_warmup_iterations = args .num_warmup_iterations ,
229- skip_accuracy_testing = (
230- args .use_synthetic_data or args .skip_accuracy_testing
231- ),
232- use_synthetic_data = args .use_synthetic_data ,
233- use_xla = args .use_xla ,
234- )
175+ runner .execute_benchmark ()
0 commit comments