Commit c736968

Xhark authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 381320453
1 parent 83ad7bd · commit c736968

File tree

4 files changed: +52 −7 lines


official/benchmark/base_benchmark.py

Lines changed: 9 additions & 4 deletions
@@ -32,12 +32,16 @@
 from official.modeling import hyperparams
 
 
-def _get_benchmark_params(benchmark_models):
+def _get_benchmark_params(benchmark_models, eval_tflite=False):
   """Formats benchmark params into a list."""
   parameterized_benchmark_params = []
   for _, benchmarks in benchmark_models.items():
     for name, params in benchmarks.items():
-      for execution_mode in ['performance', 'accuracy']:
+      if eval_tflite:
+        execution_modes = ['performance', 'tflite_accuracy']
+      else:
+        execution_modes = ['performance', 'accuracy']
+      for execution_mode in execution_modes:
         benchmark_name = '{}.{}'.format(name, execution_mode)
         benchmark_params = (
             benchmark_name,  # First arg is used by ParameterizedBenchmark.
@@ -66,7 +70,8 @@ class BaseBenchmark(  # pylint: disable=undefined-variable
 
   _benchmark_parameters = _get_benchmark_params(
       benchmark_definitions.VISION_BENCHMARKS) + _get_benchmark_params(
-          benchmark_definitions.NLP_BENCHMARKS)
+          benchmark_definitions.NLP_BENCHMARKS) + _get_benchmark_params(
+              benchmark_definitions.QAT_BENCHMARKS, True)
 
   def __init__(self,
                output_dir=None,
@@ -144,7 +149,7 @@ def benchmark(self,
         execution_mode, params, self._get_model_dir(benchmark_name))
 
     metrics = []
-    if execution_mode == 'accuracy':
+    if execution_mode in ['accuracy', 'tflite_accuracy']:
      for metric_bound in metric_bounds:
        metric = {
            'name': metric_bound['name'],
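For context, here is a minimal runnable sketch of what the updated helper produces. The loop structure matches the diff above; the input dict and the tail of each params tuple are hypothetical, since the diff only shows the first tuple element:

# Minimal sketch of the updated helper. FAKE_QAT_BENCHMARKS and the tuple
# layout after the first element are assumptions for illustration only.
def _get_benchmark_params(benchmark_models, eval_tflite=False):
  """Formats benchmark params into a list."""
  parameterized_benchmark_params = []
  for _, benchmarks in benchmark_models.items():
    for name, params in benchmarks.items():
      if eval_tflite:
        execution_modes = ['performance', 'tflite_accuracy']
      else:
        execution_modes = ['performance', 'accuracy']
      for execution_mode in execution_modes:
        benchmark_name = '{}.{}'.format(name, execution_mode)
        parameterized_benchmark_params.append(
            (benchmark_name, execution_mode, params))  # tail is assumed
  return parameterized_benchmark_params


FAKE_QAT_BENCHMARKS = {  # hypothetical stand-in for QAT_BENCHMARKS
    'vision': {'mobilenet_qat': {'experiment': 'some_experiment'}},
}
for entry in _get_benchmark_params(FAKE_QAT_BENCHMARKS, eval_tflite=True):
  print(entry[0])
# mobilenet_qat.performance
# mobilenet_qat.tflite_accuracy

So with eval_tflite=True, each QAT benchmark is parameterized into a 'performance' variant and a 'tflite_accuracy' variant instead of the plain 'accuracy' one.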

official/benchmark/benchmark_definitions.py

Lines changed: 3 additions & 0 deletions
@@ -51,3 +51,6 @@
 
 NLP_BENCHMARKS = {
 }
+
+QAT_BENCHMARKS = {
+}
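QAT_BENCHMARKS lands empty here. Judging from how base_benchmark.py consumes these dicts (a group key, then benchmark name to params, with metric_bounds read when reporting accuracy), a populated entry might look roughly like the following; every key and value below is hypothetical:

# Hypothetical populated entry, shaped after how base_benchmark.py iterates
# these dicts; none of the names or bounds below come from the real repo.
QAT_BENCHMARKS = {
    'vision': {                                   # group key, iterated but unused
        'mobilenet_qat': {                        # benchmark name
            'experiment': 'mobilenet_imagenet_qat',  # assumed param key
            'metric_bounds': [{
                'name': 'top_1_accuracy',         # read via metric_bound['name']
                'min_value': 0.70,                # assumed bound fields
                'max_value': 0.78,
            }],
        },
    },
}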

official/benchmark/benchmark_lib.py

Lines changed: 11 additions & 3 deletions
@@ -21,6 +21,7 @@
 from absl import logging
 import orbit
 import tensorflow as tf
+from official.benchmark import tflite_utils
 from official.common import distribute_utils
 from official.core import config_definitions
 from official.core import task_factory
@@ -37,15 +38,18 @@ def run_benchmark(
   """Runs benchmark for a specific experiment.
 
   Args:
-    execution_mode: A 'str', specifying the mode. Can be 'accuracy', or
-      'performance'.
+    execution_mode: A 'str', specifying the mode. Can be 'accuracy',
+      'performance', or 'tflite_accuracy'.
     params: ExperimentConfig instance.
     model_dir: A 'str', a path to store model checkpoints and summaries.
     distribution_strategy: A tf.distribute.Strategy to use. If specified,
       it will be used instead of inferring the strategy from params.
 
   Returns:
     benchmark_data: returns benchmark data in dict format.
+
+  Raises:
+    NotImplementedError: If an unsupported setup is requested.
   """
 
   # For GPU runs, allow option to set thread mode
@@ -77,7 +81,7 @@ def run_benchmark(
   trainer.initialize()
 
   steps_per_loop = params.trainer.steps_per_loop if (
-      execution_mode == 'accuracy') else 100
+      execution_mode in ['accuracy', 'tflite_accuracy']) else 100
   controller = orbit.Controller(
       strategy=strategy,
       trainer=trainer,
@@ -105,6 +109,10 @@ def run_benchmark(
     benchmark_data = {'metrics': eval_logs}
   elif execution_mode == 'performance':
     benchmark_data = {}
+  elif execution_mode == 'tflite_accuracy':
+    eval_logs = tflite_utils.train_and_evaluate(
+        params, task, trainer, controller)
+    benchmark_data = {'metrics': eval_logs}
   else:
     raise NotImplementedError(
         'The benchmark execution mode is not implemented: %s' %
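With this dispatch in place, callers request the new mode the same way as the existing ones. A usage sketch, taking run_benchmark's signature from the Args section above; the experiment name passed to exp_factory is hypothetical, and until tflite_utils.train_and_evaluate is implemented this path raises NotImplementedError:

# Sketch of invoking the new mode; 'mobilenet_imagenet_qat' is a made-up
# experiment name. Currently raises NotImplementedError from tflite_utils.
from official.benchmark import benchmark_lib
from official.core import exp_factory

params = exp_factory.get_exp_config('mobilenet_imagenet_qat')  # hypothetical
benchmark_data = benchmark_lib.run_benchmark(
    execution_mode='tflite_accuracy',
    params=params,
    model_dir='/tmp/qat_benchmark')
print(benchmark_data['metrics'])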

official/benchmark/tflite_utils.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFLite utils."""
+import orbit
+from official.core import base_task
+from official.core import base_trainer
+from official.core import config_definitions
+
+
+def train_and_evaluate(
+    params: config_definitions.ExperimentConfig,
+    task: base_task.Task,
+    trainer: base_trainer.Trainer,
+    controller: orbit.Controller):
+  """Train and evaluate on TFLite."""
+  raise NotImplementedError('train_and_evaluate on tflite_utils is not '
+                            'implemented yet.')
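The new module is deliberately a stub. For orientation, here is a hedged sketch of one plausible shape for the eventual implementation, using only standard Orbit and TFLite APIs: train via the controller, convert with tf.lite.TFLiteConverter, evaluate with tf.lite.Interpreter. Nothing below is the actual Model Garden implementation, and attributes such as trainer.model are assumptions:

# Hedged sketch only: a plausible train -> convert -> evaluate flow for the
# stubbed train_and_evaluate. Not the real implementation.
import tensorflow as tf


def train_and_evaluate(params, task, trainer, controller):
  """Trains, converts the model to TFLite, and evaluates the TFLite model."""
  # 1. Train to the configured number of steps (standard Orbit usage).
  controller.train(steps=params.trainer.train_steps)

  # 2. Convert the trained Keras model to a TFLite flatbuffer. Accessing
  #    `trainer.model` is an assumption about the Trainer API.
  converter = tf.lite.TFLiteConverter.from_keras_model(trainer.model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]  # QAT-friendly path
  tflite_model = converter.convert()

  # 3. Run the converted model over the eval data with the TFLite
  #    interpreter; only a single-input model is sketched here.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  eval_logs = {}
  for features, labels in task.build_inputs(params.task.validation_data):
    # Match the interpreter's input shape to the incoming batch.
    interpreter.resize_tensor_input(input_index, features.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_index, features.numpy())
    interpreter.invoke()
    predictions = interpreter.get_tensor(output_index)
    # Metric aggregation is elided; a real version would update the task's
    # eval metrics with (labels, predictions) and fold them into eval_logs.
  return eval_logs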
