Commit 72e7e91
No public description
PiperOrigin-RevId: 569290058
1 parent 8c30dbe commit 72e7e91

File tree: 3 files changed, 48 additions, 11 deletions

official/vision/configs/maskrcnn.py

Lines changed: 7 additions & 0 deletions
@@ -46,6 +46,13 @@ class Parser(hyperparams.Config):
   rpn_batch_size_per_im: int = 256
   rpn_fg_fraction: float = 0.5
   mask_crop_size: int = 112
+  pad: bool = True  # Only `pad = True` is supported.
+
+  def __post_init__(self, *args, **kwargs):
+    """Validates the configuration."""
+    if not self.pad:
+      raise ValueError('`maskrcnn.Parser` only supports `pad = True`.')
+    super().__post_init__(*args, **kwargs)


 @dataclasses.dataclass
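The `__post_init__` hook makes the unsupported `pad = False` fail fast at config-construction time instead of surfacing later in the input pipeline. A minimal sketch of the intended behavior, assuming the usual dataclass-style keyword construction of `hyperparams.Config` subclasses (the override below is illustrative, not part of this commit):

from official.vision.configs import maskrcnn

parser_cfg = maskrcnn.Parser()  # `pad` defaults to True; validation passes.
assert parser_cfg.pad

try:
  maskrcnn.Parser(pad=False)  # `__post_init__` rejects the override.
except ValueError as e:
  print(e)  # -> `maskrcnn.Parser` only supports `pad = True`.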

official/vision/configs/retinanet.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ class Parser(hyperparams.Config):
   max_num_instances: int = 100
   # Can choose AutoAugment and RandAugment.
   aug_type: Optional[common.Augmentation] = None
+  pad: bool = True

   # Keep for backward compatibility. Not used.
   aug_policy: Optional[str] = None
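Unlike the Mask R-CNN parser, the RetinaNet parser accepts either value; the serving exporter in detection.py below reads `params.task.train_data.parser.pad` to decide whether input images are padded to a multiple of 2**max_level. A hedged sketch of flipping the flag through the standard experiment factory (the experiment name is an assumption, not part of this commit):

from official.core import exp_factory

exp_cfg = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
exp_cfg.task.train_data.parser.pad = False  # Export with unpadded input size.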

official/vision/serving/detection.py

Lines changed: 40 additions & 11 deletions
@@ -14,11 +14,13 @@

 """Detection input and model functions for serving/inference."""

+import math
 from typing import Mapping, Tuple

 from absl import logging
 import tensorflow as tf

+from official.core import config_definitions as cfg
 from official.vision import configs
 from official.vision.modeling import factory
 from official.vision.ops import anchor
@@ -30,6 +32,34 @@
 class DetectionModule(export_base.ExportModule):
   """Detection Module."""

+  def __init__(
+      self,
+      params: cfg.ExperimentConfig,
+      *,
+      input_image_size: list[int],
+      **kwargs,
+  ):
+    """Initializes a detection module for export.
+
+    Args:
+      params: Experiment params.
+      input_image_size: List or tuple of the input image size. For a 2D
+        image, it is [height, width].
+      **kwargs: All other kwargs are passed to `export_base.ExportModule`; see
+        the documentation on `export_base.ExportModule` for valid arguments.
+    """
+    if params.task.train_data.parser.pad:
+      self._padded_size = preprocess_ops.compute_padded_size(
+          input_image_size, 2**params.task.model.max_level
+      )
+    else:
+      self._padded_size = input_image_size
+    super().__init__(
+        params=params,
+        input_image_size=input_image_size,
+        **kwargs,
+    )
+
   def _build_model(self):

     nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
@@ -40,8 +70,8 @@ def _build_model(self):
                    'does not support with dynamic batch size.', nms_version)
       self.params.task.model.detection_generator.nms_version = 'batched'

-    input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
-                                            self._input_image_size + [3])
+    input_specs = tf.keras.layers.InputSpec(shape=[
+        self._batch_size, *self._padded_size, 3])

     if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
       model = factory.build_maskrcnn(
@@ -64,23 +94,21 @@ def _build_anchor_boxes(self):
         num_scales=model_params.anchor.num_scales,
         aspect_ratios=model_params.anchor.aspect_ratios,
         anchor_size=model_params.anchor.anchor_size)
-    return input_anchor(
-        image_size=(self._input_image_size[0], self._input_image_size[1]))
+    return input_anchor(image_size=self._padded_size)

   def _build_inputs(self, image):
     """Builds detection model inputs for serving."""
-    model_params = self.params.task.model
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(
         image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._input_image_size,
-        padded_size=preprocess_ops.compute_padded_size(
-            self._input_image_size, 2**model_params.max_level),
+        padded_size=self._padded_size,
         aug_scale_min=1.0,
-        aug_scale_max=1.0)
+        aug_scale_max=1.0,
+    )
     anchor_boxes = self._build_anchor_boxes()

     return image, anchor_boxes, image_info
@@ -128,7 +156,7 @@ def preprocess(
     images = tf.cast(images, dtype=tf.float32)

     # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
-    images_spec = tf.TensorSpec(shape=self._input_image_size + [3],
+    images_spec = tf.TensorSpec(shape=self._padded_size + [3],
                                 dtype=tf.float32)

     num_anchors = model_params.anchor.num_scales * len(
@@ -137,8 +165,9 @@ def preprocess(
     for level in range(model_params.min_level, model_params.max_level + 1):
       anchor_level_spec = tf.TensorSpec(
           shape=[
-              self._input_image_size[0] // 2**level,
-              self._input_image_size[1] // 2**level, num_anchors
+              math.ceil(self._padded_size[0] / 2**level),
+              math.ceil(self._padded_size[1] / 2**level),
+              num_anchors,
           ],
           dtype=tf.float32)
       anchor_shapes.append((str(level), anchor_level_spec))
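Taken together, `self._padded_size` is now computed once in `__init__` and reused everywhere a shape was previously derived from `self._input_image_size`: the model input spec, the anchor generator, the resize-and-crop padding target, and the per-level `TensorSpec`s. The switch from `//` to `math.ceil(... / ...)` only matters when `pad = False`, since an unpadded size need not divide evenly at every pyramid level. A self-contained sketch of the arithmetic, assuming `compute_padded_size` rounds each dimension up to the nearest multiple of the stride (which matches how it is used here):

import math

def compute_padded_size(image_size, stride):
  # Assumed behavior: round each spatial dim up to a multiple of `stride`.
  return [int(math.ceil(s / stride)) * stride for s in image_size]

input_image_size = [640, 360]
max_level = 7  # Coarsest feature stride: 2**7 = 128.

padded = compute_padded_size(input_image_size, 2**max_level)
print(padded)  # [640, 384]: both dims now divide evenly at every level.

for level in range(3, max_level + 1):
  # With padding, ceil(x / 2**level) equals exact integer division.
  print(level, (math.ceil(padded[0] / 2**level),
                math.ceil(padded[1] / 2**level)))

# With pad=False the raw size is kept: 360 / 2**4 = 22.5, so the old
# `// 2**level` would understate the feature height (22 vs. the actual 23),
# while math.ceil keeps the TensorSpec consistent with the backbone output.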
