
Commit e3dbeaf

No public description
PiperOrigin-RevId: 604711233
1 parent 602905a commit e3dbeaf

File tree

2 files changed: 270 additions & 0 deletions

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script is tailored for processing outputs from two Mask R-CNN models.
16+
17+
It is designed to handle object detection and segmentation tasks, combines
18+
outputs from two Mask R-CNN models. This involves aggregating detected objects
19+
and their respective masks and bounding boxes. Identifies and removes duplicate
20+
detections in the merged result, ensuring each detected object is unique.
21+
Extracts and compiles features of the detected objects, which may include
22+
aspects like size, area, color, or other model-specific attributes.
23+
"""
24+

import sys
import numpy as np

sys.path.append(
    'models/official/projects/waste_identification_ml/model_inference/'
)
from official.projects.waste_identification_ml.model_inference import postprocessing  # pylint: disable=g-import-not-at-top,g-bad-import-order

HEIGHT, WIDTH = 512, 1024


def merge_predictions(
    results: list[dict[str, np.ndarray]],
    score: float,
    category_indices: list[list[str]],
    category_index: dict[int, dict[str, str]],
    max_detection: int,
) -> dict[str, np.ndarray]:
  """Merges and refines prediction results.

  This function takes the prediction results from two models, reframes the
  masks to the original image size, and aligns similar masks from both model
  outputs. It then merges these masks into a single result based on the given
  threshold criteria. The criteria include a minimum score threshold, an area
  threshold, and category alignment using the provided indices and dictionary.

  Args:
    results: Outputs from the two Mask R-CNN models.
    score: The minimum score threshold for filtering out detections.
    category_indices: Class labels of the two models.
    category_index: A dictionary mapping class IDs to class labels.
    max_detection: Maximum number of detections from both models.

  Returns:
    Merged and filtered detection results.
  """
  # This threshold is used to eliminate all detected objects whose area is
  # greater than 'area_threshold'.
  area_threshold = 0.3 * HEIGHT * WIDTH

  # Reframe the masks from the model outputs to the original image size.
  results_reframed = [
      postprocessing.reframing_masks(detection, HEIGHT, WIDTH)
      for detection in results
  ]

  # Align similar masks from both model outputs and merge their properties
  # into a single mask. The function compares only the first 'max_detection'
  # objects. Objects scoring below 'score' are eliminated, as are objects
  # whose area exceeds 'area_threshold'. 'category_indices' and
  # 'category_index' are used to find the label from the combination of
  # labels from the two individual models. The output includes masks
  # appearing in either model, provided they satisfy the criteria.
  final_result = postprocessing.find_similar_masks(
      results_reframed[0],
      results_reframed[1],
      max_detection,
      score,
      category_indices,
      category_index,
      area_threshold,
  )
  return final_result


def _transform_bounding_boxes(
    results: dict[str, np.ndarray]
) -> list[list[int]]:
  """Transforms normalized bounding box coordinates to the original scale.

  This function takes a dictionary containing normalized bounding box
  coordinates and transforms these coordinates to the original image scale
  using the module-level image height and width.

  Args:
    results: A dictionary containing detection results. Expected to have a key
      'detection_boxes' with a numpy array of normalized coordinates.

  Returns:
    A list of transformed bounding boxes, each represented as
    [ymin, xmin, ymax, xmax] in the original image scale.
  """
  transformed_boxes = []
  for bb in results['detection_boxes'][0]:
    ymin = int(bb[0] * HEIGHT)
    xmin = int(bb[1] * WIDTH)
    ymax = int(bb[2] * HEIGHT)
    xmax = int(bb[3] * WIDTH)
    transformed_boxes.append([ymin, xmin, ymax, xmax])
  return transformed_boxes
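
For orientation, here is a minimal, hypothetical usage sketch of the two helpers added above; it is not part of the commit. The variables results_model_1 and results_model_2 are placeholders for the raw output dictionaries of the two Mask R-CNN models, and the score threshold, label structures, and detection cap mirror the values used in the tests below.

from official.projects.waste_identification_ml.docker_solution.prediction_pipeline import prediction_postprocessing

# Class labels of each model and the combined id-to-name mapping, as used in
# the tests below.
category_indices = [[1, 2], [2, 1]]
category_index = {
    1: {'id': 1, 'name': 'class_1'},
    2: {'id': 2, 'name': 'class_2'},
}

# Merge the two raw model outputs, keeping at most 4 detections that score at
# least 0.8 and whose area does not exceed 30% of the 512x1024 frame.
merged = prediction_postprocessing.merge_predictions(
    [results_model_1, results_model_2],  # placeholder model outputs
    0.8,
    category_indices,
    category_index,
    4,
)

# Convert the normalized boxes of a raw result to pixel coordinates using the
# module-level private helper.
pixel_boxes = prediction_postprocessing._transform_bounding_boxes(results_model_1)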
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest import mock
import numpy as np
from official.projects.waste_identification_ml.docker_solution.prediction_pipeline import prediction_postprocessing


class PostprocessingTest(unittest.TestCase):

  def setUp(self):
    super().setUp()
    self.results1 = {
        'detection_boxes': [np.array([[0, 0, 100, 100], [100, 100, 200, 200]])],
        'detection_masks': [
            np.zeros((1, 512, 1024), dtype=np.uint8),
            np.ones((1, 512, 1024), dtype=np.uint8),
        ],
        'detection_scores': [[0.9, 0.8]],
        'detection_classes': [1, 2],
        'detection_classes_names': ['class_1', 'class_2'],
    }

    self.results2 = {
        'detection_boxes': [
            np.array([[50, 50, 150, 150], [150, 150, 250, 250]])
        ],
        'detection_masks': [
            np.full((1, 512, 1024), 0.5, dtype=np.uint8),
            np.full((1, 512, 1024), 0.5, dtype=np.uint8),
        ],
        'detection_scores': [[0.9, 0.8]],
        'detection_classes': [2, 1],
        'detection_classes_names': ['class_2', 'class_1'],
    }

    self.category_indices = [[1, 2], [2, 1]]

    self.category_index = {
        1: {'id': 1, 'name': 'class_1'},
        2: {'id': 2, 'name': 'class_2'},
    }
    self.height = 512
    self.width = 1024

  def test_merge_predictions(self):
    results = prediction_postprocessing.merge_predictions(
        [self.results1, self.results2],
        0.8,
        self.category_indices,
        self.category_index,
        4,
    )

    self.assertEqual(results['num_detections'], 4)
    self.assertEqual(results['detection_scores'].shape, (4,))
    self.assertEqual(results['detection_boxes'].shape, (4, 4))
    self.assertEqual(results['detection_classes'].shape, (4,))
    self.assertEqual(
        results['detection_classes_names'],
        ['class_1', 'class_2', 'class_1', 'class_2'],
    )
    self.assertEqual(results['detection_masks_reframed'].shape, (4, 512, 1024))

  # Patch find_similar_masks on the postprocessing module that
  # prediction_postprocessing actually imports and calls.
  @mock.patch.object(
      prediction_postprocessing.postprocessing, 'find_similar_masks'
  )
  def test_merge_predictions_calls_find_similar_masks(
      self, mock_find_similar_masks
  ):
    prediction_postprocessing.merge_predictions(
        [self.results1, self.results2],
        0.8,
        self.category_indices,
        self.category_index,
        4,
    )

    mock_find_similar_masks.assert_called_once_with(
        self.results1,
        self.results2,
        4,
        0.8,
        self.category_indices,
        self.category_index,
        0.3 * 512 * 1024,
    )

  def test_merge_predictions_with_empty_results(self):
    results = prediction_postprocessing.merge_predictions(
        [{}, {}],
        0.8,
        self.category_indices,
        self.category_index,
        4,
    )

    self.assertEqual(results['num_detections'], 0)
    self.assertEqual(results['detection_scores'].shape, (0,))
    self.assertEqual(results['detection_boxes'].shape, (0, 4))
    self.assertEqual(results['detection_classes'].shape, (0,))
    self.assertEqual(results['detection_classes_names'], [])
    self.assertEqual(results['detection_masks_reframed'].shape, (0, 512, 1024))

  def test_merge_predictions_with_invalid_category_indices(self):
    category_indices = [[1, 3], [2, 4]]

    with self.assertRaises(ValueError):
      prediction_postprocessing.merge_predictions(
          [self.results1, self.results2],
          0.8,
          category_indices,
          self.category_index,
          4,
      )

  def test_transform_bounding_boxes(self):
    results = {
        'detection_boxes': np.array([[
            [0.1, 0.2, 0.4, 0.5],  # Normalized coordinates
            [0.3, 0.3, 0.6, 0.7],
        ]])
    }

    # Expected output for the adjusted height and width
    expected_transformed_boxes = [
        [
            int(0.1 * self.height),
            int(0.2 * self.width),
            int(0.4 * self.height),
            int(0.5 * self.width),
        ],
        [
            int(0.3 * self.height),
            int(0.3 * self.width),
            int(0.6 * self.height),
            int(0.7 * self.width),
        ],
    ]

    transformed_boxes = prediction_postprocessing._transform_bounding_boxes(
        results
    )

    self.assertEqual(transformed_boxes, expected_transformed_boxes)
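
As a quick sanity check on the expected values in test_transform_bounding_boxes (a worked example, not part of the commit), the normalized boxes map to pixel coordinates as follows, given HEIGHT = 512, WIDTH = 1024 and int() truncation:

# [ymin, xmin, ymax, xmax] normalized -> pixel coordinates.
assert [int(0.1 * 512), int(0.2 * 1024), int(0.4 * 512), int(0.5 * 1024)] == [51, 204, 204, 512]
assert [int(0.3 * 512), int(0.3 * 1024), int(0.6 * 512), int(0.7 * 1024)] == [153, 307, 307, 716]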
