# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Evaluation executable for detection data.

This executable evaluates precomputed detections produced by a detection
model and writes the evaluation results into a CSV file named metrics.csv,
stored in the directory specified by --eval_dir.

The set of evaluation metrics is specified via the metrics_set field of
object_detection.protos.EvalConfig. Currently two sets of metrics are
supported:
- pascal_voc_metrics: standard PASCAL VOC 2007 metric
- open_images_metrics: Open Images V2 metric
All other fields of object_detection.protos.EvalConfig are ignored.

Example usage:
  ./compute_metrics \
    --eval_dir=path/to/eval_dir \
    --eval_config_path=path/to/evaluation/configuration/file \
    --input_config_path=path/to/input/configuration/file
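
The two configuration files are plain text protos. As a rough sketch (all
paths are placeholders and pascal_voc_metrics is just one of the supported
choices), the relevant pieces might look like:

  # eval_config_path
  metrics_set: 'pascal_voc_metrics'

  # input_config_path
  label_map_path: 'path/to/label_map.pbtxt'
  tf_record_input_reader {
    input_path: 'path/to/detections_and_groundtruth.record@10'
  }

An optional '@N' suffix on input_path denotes a file sharded into N pieces;
the tool expands it into the individual shard filenames.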
"""
import csv
import os
import re
import tensorflow as tf

from object_detection import evaluator
from object_detection.core import standard_fields
from object_detection.metrics import tf_example_parser
from object_detection.utils import config_util
from object_detection.utils import label_map_util

flags = tf.app.flags
tf.logging.set_verbosity(tf.logging.INFO)

flags.DEFINE_string('eval_dir', None, 'Directory to write eval summaries to.')
flags.DEFINE_string('eval_config_path', None,
                    'Path to an eval_pb2.EvalConfig config file.')
flags.DEFINE_string('input_config_path', None,
                    'Path to an input_reader_pb2.InputReader config file.')

FLAGS = flags.FLAGS


def _generate_sharded_filenames(filename):
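  """Expands a sharded filename pattern into its per-shard filenames.

  For example, 'detections.record@2' expands to
  ['detections.record-00000-of-00002', 'detections.record-00001-of-00002'],
  while a filename without an '@N' suffix is returned as a one-element list.
  """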
  m = re.search(r'@(\d{1,})', filename)
  if m:
    num_shards = int(m.group(1))
    return [
        re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards), filename)
        for i in range(num_shards)
    ]
  else:
    return [filename]


def _generate_filenames(filenames):
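  """Flattens a list of (possibly sharded) filename patterns into filenames."""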
  result = []
  for filename in filenames:
    result += _generate_sharded_filenames(filename)
  return result


def read_data_and_evaluate(input_config, eval_config):
  """Reads pre-computed object detections and groundtruth from tf_record.

  Args:
    input_config: input config proto of type
      object_detection.protos.InputReader.
    eval_config: evaluation config proto of type
      object_detection.protos.EvalConfig.

  Returns:
    Evaluated detection metrics.

  Raises:
    ValueError: if input_reader type is not supported or metric type is unknown.
  """
  if input_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    input_paths = input_config.tf_record_input_reader.input_path

    label_map = label_map_util.load_labelmap(input_config.label_map_path)
    max_num_classes = max([item.id for item in label_map.item])
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)

    object_detection_evaluators = evaluator.get_evaluators(
        eval_config, categories)
    # Support a single evaluator
    object_detection_evaluator = object_detection_evaluators[0]

    skipped_images = 0
    processed_images = 0
    for input_path in _generate_filenames(input_paths):
      tf.logging.info('Processing file: {0}'.format(input_path))

      record_iterator = tf.python_io.tf_record_iterator(path=input_path)
      data_parser = tf_example_parser.TfExampleDetectionAndGTParser()

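      # Each record is a serialized tf.Example that carries both the
      # groundtruth and the detections for one image. The parser returns a
      # dict keyed by the names in standard_fields, or a falsy value when
      # required fields are missing (such images are counted as skipped).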
      for string_record in record_iterator:
        tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 1000,
                               processed_images)
        processed_images += 1

        example = tf.train.Example()
        example.ParseFromString(string_record)
        decoded_dict = data_parser.parse(example)

        if decoded_dict:
          object_detection_evaluator.add_single_ground_truth_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
          object_detection_evaluator.add_single_detected_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
        else:
          skipped_images += 1
          tf.logging.info('Skipped images: {0}'.format(skipped_images))

    return object_detection_evaluator.evaluate()

  raise ValueError('Unsupported input_reader_config.')


def write_metrics(metrics, output_dir):
  """Write metrics to the output directory.

  Args:
    metrics: A dictionary containing metric names and values.
    output_dir: Directory to write metrics to.
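
  Each metric is written as one 'metric_name,metric_value' row of
  metrics.csv. The metric names depend on the metrics_set configured in the
  EvalConfig; with pascal_voc_metrics, for instance, a row might look
  roughly like 'PascalBoxes_Precision/mAP@0.5IOU,0.651' (illustrative value).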
  """
  tf.logging.info('Writing metrics.')

  with open(os.path.join(output_dir, 'metrics.csv'), 'w') as csvfile:
    metrics_writer = csv.writer(csvfile, delimiter=',')
    for metric_name, metric_value in metrics.items():
      metrics_writer.writerow([metric_name, str(metric_value)])


def main(argv):
  del argv
  required_flags = ['input_config_path', 'eval_config_path', 'eval_dir']
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))

  configs = config_util.get_configs_from_multiple_files(
      eval_input_config_path=FLAGS.input_config_path,
      eval_config_path=FLAGS.eval_config_path)

  eval_config = configs['eval_config']
  input_config = configs['eval_input_config']

  metrics = read_data_and_evaluate(input_config, eval_config)

  # Save metrics
  write_metrics(metrics, FLAGS.eval_dir)


if __name__ == '__main__':
  tf.app.run(main)