Skip to content

Commit 1ae7f5c

Browse files
committed
Add tests for negative samples for Mask R-CNN and Keypoint R-CNN (#2069)
* Add tests for negative samples for Mask R-CNN and Keypoint R-CNN * Fix lint
1 parent e61538c commit 1ae7f5c

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

test/test_models_detection_negative_samples.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,27 @@
1111

1212
class Tester(unittest.TestCase):
1313

14-
def test_targets_to_anchors(self):
14+
def _make_empty_sample(self, add_masks=False, add_keypoints=False):
15+
images = [torch.rand((3, 100, 100), dtype=torch.float32)]
1516
boxes = torch.zeros((0, 4), dtype=torch.float32)
1617
negative_target = {"boxes": boxes,
17-
"labels": torch.zeros((1, 1), dtype=torch.int64),
18+
"labels": torch.zeros(0, dtype=torch.int64),
1819
"image_id": 4,
1920
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
2021
"iscrowd": torch.zeros((0,), dtype=torch.int64)}
2122

22-
anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
23+
if add_masks:
24+
negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)
25+
26+
if add_keypoints:
27+
negative_target["keypoints"] = torch.zeros(17, 0, 3, dtype=torch.float32)
28+
2329
targets = [negative_target]
30+
return images, targets
31+
32+
def test_targets_to_anchors(self):
33+
_, targets = self._make_empty_sample()
34+
anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
2435

2536
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
2637
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
@@ -85,25 +96,37 @@ def test_assign_targets_to_proposals(self):
8596
self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
8697
self.assertEqual(labels[0].dtype, torch.int64)
8798

88-
def test_forward_negative_sample(self):
89-
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
90-
in_features = model.roi_heads.box_predictor.cls_score.in_features
91-
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
99+
def test_forward_negative_sample_frcnn(self):
100+
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
101+
num_classes=2, min_size=100, max_size=100)
92102

93-
images = [torch.rand((3, 100, 100), dtype=torch.float32)]
94-
boxes = torch.zeros((0, 4), dtype=torch.float32)
95-
negative_target = {"boxes": boxes,
96-
"labels": torch.zeros((1, 1), dtype=torch.int64),
97-
"image_id": 4,
98-
"area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
99-
"iscrowd": torch.zeros((0,), dtype=torch.int64)}
103+
images, targets = self._make_empty_sample()
104+
loss_dict = model(images, targets)
100105

101-
targets = [negative_target]
106+
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
107+
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
108+
109+
def test_forward_negative_sample_mrcnn(self):
110+
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
111+
num_classes=2, min_size=100, max_size=100)
112+
113+
images, targets = self._make_empty_sample(add_masks=True)
114+
loss_dict = model(images, targets)
115+
116+
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
117+
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
118+
self.assertEqual(loss_dict["loss_mask"], torch.tensor(0.))
119+
120+
def test_forward_negative_sample_krcnn(self):
121+
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
122+
num_classes=2, min_size=100, max_size=100)
102123

124+
images, targets = self._make_empty_sample(add_keypoints=True)
103125
loss_dict = model(images, targets)
104126

105127
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
106128
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
129+
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
107130

108131

109132
if __name__ == '__main__':

torchvision/ops/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class can go away.
113113
)
114114

115115
output_shape = _output_size(2, input, size, scale_factor)
116-
output_shape = input.shape[:-2] + output_shape
116+
output_shape = list(input.shape[:-2]) + output_shape
117117
return _new_empty_tensor(input, output_shape)
118118

119119

0 commit comments

Comments
 (0)