14
14
15
15
"""Detection input and model functions for serving/inference."""
16
16
17
+ import math
17
18
from typing import Mapping , Tuple
18
19
19
20
from absl import logging
20
21
import tensorflow as tf
21
22
23
+ from official .core import config_definitions as cfg
22
24
from official .vision import configs
23
25
from official .vision .modeling import factory
24
26
from official .vision .ops import anchor
30
32
class DetectionModule (export_base .ExportModule ):
31
33
"""Detection Module."""
32
34
35
def __init__(
    self,
    params: cfg.ExperimentConfig,
    *,
    input_image_size: list[int],
    **kwargs,
):
  """Initializes a detection module for export.

  The padded image size is derived once here and stored on
  `self._padded_size`. When `params.task.train_data.parser.pad` is set, it is
  computed by `preprocess_ops.compute_padded_size` from `input_image_size`
  and `2 ** params.task.model.max_level`; otherwise it is a copy of
  `input_image_size`.

  Args:
    params: Experiment params.
    input_image_size: List or Tuple of size of the input image. For 2D image,
      it is [height, width].
    **kwargs: All other kwargs are passed to `export_base.ExportModule`; see
      the documentation on `export_base.ExportModule` for valid arguments.
  """
  if params.task.train_data.parser.pad:
    self._padded_size = preprocess_ops.compute_padded_size(
        input_image_size, 2 ** params.task.model.max_level
    )
  else:
    # Store an independent list rather than the caller's object: a tuple is
    # an accepted input, but other methods of this class build shapes with
    # `self._padded_size + [3]`, which only works for a list. Copying also
    # keeps later caller-side mutation from changing the padded size.
    self._padded_size = list(input_image_size)
  # NOTE(review): `_padded_size` is assigned before the base initializer runs,
  # presumably because the base initializer ends up calling `_build_model`,
  # which reads it -- confirm against `export_base.ExportModule` before
  # reordering these statements.
  super().__init__(
      params=params,
      input_image_size=input_image_size,
      **kwargs,
  )
62
+
33
63
def _build_model (self ):
34
64
35
65
nms_versions_supporting_dynamic_batch_size = {'batched' , 'v2' , 'v3' }
@@ -40,8 +70,8 @@ def _build_model(self):
40
70
'does not support with dynamic batch size.' , nms_version )
41
71
self .params .task .model .detection_generator .nms_version = 'batched'
42
72
43
- input_specs = tf .keras .layers .InputSpec (shape = [self . _batch_size ] +
44
- self ._input_image_size + [ 3 ])
73
+ input_specs = tf .keras .layers .InputSpec (shape = [
74
+ self . _batch_size , * self ._padded_size , 3 ])
45
75
46
76
if isinstance (self .params .task .model , configs .maskrcnn .MaskRCNN ):
47
77
model = factory .build_maskrcnn (
@@ -64,23 +94,21 @@ def _build_anchor_boxes(self):
64
94
num_scales = model_params .anchor .num_scales ,
65
95
aspect_ratios = model_params .anchor .aspect_ratios ,
66
96
anchor_size = model_params .anchor .anchor_size )
67
- return input_anchor (
68
- image_size = (self ._input_image_size [0 ], self ._input_image_size [1 ]))
97
+ return input_anchor (image_size = self ._padded_size )
69
98
70
99
def _build_inputs(self, image):
  """Builds detection model inputs for serving.

  The image is normalized with `preprocess_ops.MEAN_RGB` and
  `preprocess_ops.STDDEV_RGB`, then resized and cropped to
  `self._input_image_size` with `self._padded_size` as the padded size.

  Args:
    image: The image to preprocess, as accepted by
      `preprocess_ops.normalize_image`.

  Returns:
    A tuple `(image, anchor_boxes, image_info)`: the preprocessed image, the
    result of `self._build_anchor_boxes()`, and the image info returned by
    `preprocess_ops.resize_and_crop_image`.
  """
  # Normalizes image with mean and std pixel values.
  normalized = preprocess_ops.normalize_image(
      image,
      offset=preprocess_ops.MEAN_RGB,
      scale=preprocess_ops.STDDEV_RGB,
  )

  # Both augmentation scale bounds are fixed at 1.0 for serving.
  resize_options = dict(
      padded_size=self._padded_size,
      aug_scale_min=1.0,
      aug_scale_max=1.0,
  )
  resized, image_info = preprocess_ops.resize_and_crop_image(
      normalized, self._input_image_size, **resize_options
  )

  return resized, self._build_anchor_boxes(), image_info
@@ -128,7 +156,7 @@ def preprocess(
128
156
images = tf .cast (images , dtype = tf .float32 )
129
157
130
158
# Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
131
- images_spec = tf .TensorSpec (shape = self ._input_image_size + [3 ],
159
+ images_spec = tf .TensorSpec (shape = self ._padded_size + [3 ],
132
160
dtype = tf .float32 )
133
161
134
162
num_anchors = model_params .anchor .num_scales * len (
@@ -137,8 +165,9 @@ def preprocess(
137
165
for level in range (model_params .min_level , model_params .max_level + 1 ):
138
166
anchor_level_spec = tf .TensorSpec (
139
167
shape = [
140
- self ._input_image_size [0 ] // 2 ** level ,
141
- self ._input_image_size [1 ] // 2 ** level , num_anchors
168
+ math .ceil (self ._padded_size [0 ] / 2 ** level ),
169
+ math .ceil (self ._padded_size [1 ] / 2 ** level ),
170
+ num_anchors ,
142
171
],
143
172
dtype = tf .float32 )
144
173
anchor_shapes .append ((str (level ), anchor_level_spec ))
0 commit comments