Skip to content

Commit 2c590a0

Browse files
arashwantensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 424463592
1 parent ca8c44d commit 2c590a0

File tree

5 files changed

+33
-9
lines changed

5 files changed

+33
-9
lines changed

official/vision/beta/projects/panoptic_maskrcnn/configs/panoptic_maskrcnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from official.vision.beta.configs import common
2626
from official.vision.beta.configs import maskrcnn
2727
from official.vision.beta.configs import semantic_segmentation
28+
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deepmac_maskrcnn
2829

2930

3031
SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
@@ -89,7 +90,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
8990

9091

9192
@dataclasses.dataclass
92-
class PanopticMaskRCNN(maskrcnn.MaskRCNN):
93+
class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
9394
"""Panoptic Mask R-CNN model config."""
9495
segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
9596
SEGMENTATION_MODEL(num_classes=2))

official/vision/beta/projects/panoptic_maskrcnn/modeling/factory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import tensorflow as tf
1818

1919
from official.vision.beta.modeling import backbones
20-
from official.vision.beta.modeling import factory as models_factory
2120
from official.vision.beta.modeling.decoders import factory as decoder_factory
2221
from official.vision.beta.modeling.heads import segmentation_heads
22+
from official.vision.beta.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
2323
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
2424
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
2525
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
@@ -47,7 +47,7 @@ def build_panoptic_maskrcnn(
4747
segmentation_config = model_config.segmentation_model
4848

4949
# Builds the maskrcnn model.
50-
maskrcnn_model = models_factory.build_maskrcnn(
50+
maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
5151
input_specs=input_specs,
5252
model_config=model_config,
5353
l2_regularizer=l2_regularizer)
@@ -117,6 +117,7 @@ def build_panoptic_maskrcnn(
117117

118118
# Combines maskrcnn, and segmentation models to build panoptic segmentation
119119
# model.
120+
120121
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
121122
backbone=maskrcnn_model.backbone,
122123
decoder=maskrcnn_model.decoder,

official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818

1919
import tensorflow as tf
2020

21-
from official.vision.beta.modeling import maskrcnn_model
21+
from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model
2222

2323

24-
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
24+
class PanopticMaskRCNNModel(maskrcnn_model.DeepMaskRCNNModel):
2525
"""The Panoptic Segmentation model."""
2626

2727
def __init__(
@@ -49,7 +49,8 @@ def __init__(
4949
max_level: Optional[int] = None,
5050
num_scales: Optional[int] = None,
5151
aspect_ratios: Optional[List[float]] = None,
52-
anchor_size: Optional[float] = None, # pytype: disable=annotation-type-mismatch # typed-keras
52+
anchor_size: Optional[float] = None,
53+
use_gt_boxes_for_masks: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
5354
**kwargs):
5455
"""Initializes the Panoptic Mask R-CNN model.
5556
@@ -94,6 +95,7 @@ def __init__(
9495
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
9596
anchor_size: A number representing the scale of size of the base anchor to
9697
the feature stride 2^level.
98+
use_gt_boxes_for_masks: `bool`, whether to use only gt boxes for masks.
9799
**kwargs: keyword arguments to be passed.
98100
"""
99101
super(PanopticMaskRCNNModel, self).__init__(
@@ -115,6 +117,7 @@ def __init__(
115117
num_scales=num_scales,
116118
aspect_ratios=aspect_ratios,
117119
anchor_size=anchor_size,
120+
use_gt_boxes_for_masks=use_gt_boxes_for_masks,
118121
**kwargs)
119122

120123
self._config_dict.update({

official/vision/beta/projects/panoptic_maskrcnn/serving/panoptic_segmentation.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def serve(self, images: tf.Tensor):
9797
anchor_boxes=anchor_boxes,
9898
training=False)
9999

100+
detections.pop('rpn_boxes')
101+
detections.pop('rpn_scores')
102+
detections.pop('cls_outputs')
103+
detections.pop('box_outputs')
104+
detections.pop('backbone_features')
105+
detections.pop('decoder_features')
106+
107+
# Normalize detection boxes to [0, 1]. Here we first map them to the
108+
# original image size, then normalize them to [0, 1].
109+
detections['detection_boxes'] = (
110+
detections['detection_boxes'] /
111+
tf.tile(image_info[:, 2:3, :], [1, 1, 2]) /
112+
tf.tile(image_info[:, 0:1, :], [1, 1, 2]))
113+
100114
if model_params.detection_generator.apply_nms:
101115
final_outputs = {
102116
'detection_boxes': detections['detection_boxes'],
@@ -109,10 +123,15 @@ def serve(self, images: tf.Tensor):
109123
'decoded_boxes': detections['decoded_boxes'],
110124
'decoded_box_scores': detections['decoded_box_scores']
111125
}
112-
126+
masks = detections['segmentation_outputs']
127+
masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
128+
classes = tf.math.argmax(masks, axis=-1)
129+
scores = tf.nn.softmax(masks, axis=-1)
113130
final_outputs.update({
114131
'detection_masks': detections['detection_masks'],
115-
'segmentation_outputs': detections['segmentation_outputs'],
132+
'masks': masks,
133+
'scores': scores,
134+
'classes': classes,
116135
'image_info': image_info
117136
})
118137
if model_params.generate_panoptic_masks:

official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def build_model(self) -> tf.keras.Model:
6161
def initialize(self, model: tf.keras.Model) -> None:
6262
"""Loading pretrained checkpoint."""
6363

64-
if not self.task_config.init_checkpoint_modules:
64+
if not self.task_config.init_checkpoint:
6565
return
6666

6767
def _get_checkpoint_path(checkpoint_dir_or_file):

0 commit comments

Comments
 (0)