|
14 | 14 |
|
15 | 15 | """Contains definitions of generators to generate the final detections."""
|
16 | 16 | import contextlib
|
17 |
| -from typing import Any, Dict, List, Optional, Mapping, Sequence |
| 17 | +from typing import Any, Dict, List, Optional, Mapping, Sequence, Tuple |
18 | 18 | # Import libraries
|
19 | 19 | import tensorflow as tf
|
20 | 20 |
|
| 21 | +from official.projects.edgetpu.vision.modeling import custom_layers |
21 | 22 | from official.vision.ops import box_ops
|
22 | 23 | from official.vision.ops import nms
|
23 | 24 | from official.vision.ops import preprocess_ops
|
@@ -372,6 +373,93 @@ def _generate_detections_v2(boxes: tf.Tensor,
|
372 | 373 | return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
|
373 | 374 |
|
374 | 375 |
|
| 376 | +def _generate_detections_v3( |
| 377 | + boxes: tf.Tensor, |
| 378 | + scores: tf.Tensor, |
| 379 | + pre_nms_score_threshold: float = 0.05, |
| 380 | + nms_iou_threshold: float = 0.5, |
| 381 | + max_num_detections: int = 100 |
| 382 | +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: |
| 383 | + """Generates the detections given the model outputs using NMS for EdgeTPU. |
| 384 | +
|
| 385 | + Args: |
| 386 | + boxes: A `tf.Tensor` with shape `[batch_size, num_classes, N, 4]` or |
| 387 | + `[batch_size, 1, N, 4]`, which box predictions on all feature levels. The |
| 388 | + N is the number of total anchors on all levels. |
| 389 | + scores: A `tf.Tensor` with shape `[batch_size, num_classes, N]`, which |
| 390 | + stacks class probability on all feature levels. The N is the number of |
| 391 | + total anchors on all levels. The num_classes is the number of classes |
| 392 | + predicted by the model. Note that the class_outputs here is the raw score. |
| 393 | + pre_nms_score_threshold: A `float` representing the threshold for deciding |
| 394 | + when to remove boxes based on score. |
| 395 | + nms_iou_threshold: A `float` representing the threshold for deciding whether |
| 396 | + boxes overlap too much with respect to IOU. |
| 397 | + max_num_detections: A `scalar` representing maximum number of boxes retained |
| 398 | + over all classes. |
| 399 | +
|
| 400 | + Returns: |
| 401 | + nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4] |
| 402 | + representing top detected boxes in [y1, x1, y2, x2]. |
| 403 | + nms_scores: A `float` tf.Tensor of shape [batch_size, max_num_detections] |
| 404 | + representing sorted confidence scores for detected boxes. The values are |
| 405 | + between [0, 1]. |
| 406 | + nms_classes: An `int` tf.Tensor of shape [batch_size, max_num_detections] |
| 407 | + representing classes for detected boxes. |
| 408 | + valid_detections: An `int` tf.Tensor of shape [batch_size] only the top |
| 409 | + `valid_detections` boxes are valid detections. |
| 410 | +
|
| 411 | + Raises: |
| 412 | + ValueError if inputs shapes are not valid. |
| 413 | + """ |
| 414 | + with tf.name_scope('generate_detections'): |
| 415 | + batch_size, num_box_classes, box_locations, sides = ( |
| 416 | + boxes.get_shape().as_list()) |
| 417 | + if batch_size is None: |
| 418 | + batch_size = tf.shape(boxes)[0] |
| 419 | + _, num_classes, locations = scores.get_shape().as_list() |
| 420 | + if num_box_classes != 1 and num_box_classes != num_classes: |
| 421 | + raise ValueError('Boxes should have either 1 class or same as scores.') |
| 422 | + if locations != box_locations: |
| 423 | + raise ValueError('Number of locations is different.') |
| 424 | + if sides != 4: |
| 425 | + raise ValueError('Number of sides is incorrect.') |
| 426 | + # Selects pre_nms_score_threshold scores before NMS. |
| 427 | + boxes, scores = box_ops.filter_boxes_by_scores( |
| 428 | + boxes, scores, min_score_threshold=pre_nms_score_threshold) |
| 429 | + |
| 430 | + # EdgeTPU-friendly class-wise NMS, -1 for invalid. |
| 431 | + indices = custom_layers.non_max_suppression_padded( |
| 432 | + boxes, |
| 433 | + scores, |
| 434 | + max_num_detections, |
| 435 | + iou_threshold=nms_iou_threshold) |
| 436 | + # Gather NMS-ed boxes and scores. |
| 437 | + safe_indices = tf.nn.relu(indices) # 0 for invalid |
| 438 | + invalid_detections = safe_indices - indices # 1 for invalid, 0 for valid |
| 439 | + valid_detections = 1.0 - invalid_detections # 0 for invalid, 1 for valid |
| 440 | + safe_indices = tf.cast(safe_indices, tf.int32) |
| 441 | + boxes = tf.expand_dims(valid_detections, -1) * tf.gather( |
| 442 | + boxes, safe_indices, axis=2, batch_dims=2) |
| 443 | + scores = valid_detections * tf.gather( |
| 444 | + scores, safe_indices, axis=2, batch_dims=2) |
| 445 | + # Compliment with class numbers. |
| 446 | + classes = tf.range(num_classes, dtype=tf.float32) |
| 447 | + classes = tf.reshape(classes, [1, num_classes, 1]) |
| 448 | + classes = tf.tile(classes, [batch_size, 1, max_num_detections]) |
| 449 | + # Flatten classes, locations. Class = -1 for invalid detection |
| 450 | + scores = tf.reshape(scores, [batch_size, num_classes * max_num_detections]) |
| 451 | + boxes = tf.reshape(boxes, [batch_size, num_classes * max_num_detections, 4]) |
| 452 | + classes = tf.reshape(valid_detections * classes - invalid_detections, |
| 453 | + [batch_size, num_classes * max_num_detections]) |
| 454 | + # Filter top-k across boxes of all classes |
| 455 | + scores, indices = tf.nn.top_k(scores, k=max_num_detections, sorted=True) |
| 456 | + boxes = tf.gather(boxes, indices, batch_dims=1, axis=1) |
| 457 | + classes = tf.gather(classes, indices, batch_dims=1, axis=1) |
| 458 | + invalid_detections = tf.nn.relu(classes) - classes |
| 459 | + valid_detections = tf.reduce_sum(1. - invalid_detections, axis=1) |
| 460 | + return boxes, scores, classes, valid_detections |
| 461 | + |
| 462 | + |
375 | 463 | def _generate_detections_batched(boxes: tf.Tensor, scores: tf.Tensor,
|
376 | 464 | pre_nms_score_threshold: float,
|
377 | 465 | nms_iou_threshold: float,
|
@@ -997,6 +1085,19 @@ def __call__(self,
|
997 | 1085 | max_num_detections=self._config_dict['max_num_detections']))
|
998 | 1086 | # Set `nmsed_attributes` to None for v2.
|
999 | 1087 | nmsed_attributes = {}
|
| 1088 | + elif self._config_dict['nms_version'] == 'v3': |
| 1089 | + # TODO(tohaspiridonov): add compatible version of |
| 1090 | + # `_decode_multilevel_outputs` in cl/485381750 |
| 1091 | + (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections) = ( |
| 1092 | + _generate_detections_v3( |
| 1093 | + tf.transpose(boxes, [0, 2, 1, 3]), |
| 1094 | + tf.transpose(scores, [0, 2, 1]), |
| 1095 | + pre_nms_score_threshold=self |
| 1096 | + ._config_dict['pre_nms_score_threshold'], |
| 1097 | + nms_iou_threshold=self._config_dict['nms_iou_threshold'], |
| 1098 | + max_num_detections=self._config_dict['max_num_detections'])) |
| 1099 | + # Set `nmsed_attributes` to None for v2. |
| 1100 | + nmsed_attributes = {} |
1000 | 1101 | else:
|
1001 | 1102 | raise ValueError('NMS version {} not supported.'.format(
|
1002 | 1103 | self._config_dict['nms_version']))
|
|
0 commit comments