|
| 1 | +import torch |
| 2 | + |
| 3 | +from SuperGluePretrainedNetwork.models.superglue import ( |
| 4 | + SuperGlue, |
| 5 | + arange_like, |
| 6 | + log_optimal_transport, |
| 7 | + normalize_keypoints, |
| 8 | +) |
| 9 | + |
| 10 | + |
class SuperGlueWrapper(SuperGlue):
    """SuperGlue matcher whose ``forward`` takes flat positional arguments.

    The parent ``SuperGlue`` network is reused unchanged (keypoint encoder,
    graph neural network, final projection and dustbin score); only the
    calling convention of ``forward`` differs: the inputs are passed as
    individual arguments instead of a single dictionary.
    """

    # NOTE(review): presumably merged with the user-supplied config by
    # ``SuperGlue.__init__`` (as in the upstream implementation) -- confirm.
    # "weights" must name an existing checkpoint; upstream ships "indoor" and
    # "outdoor", so the previous value "outdor" could never be loaded.
    default_config = {
        "descriptor_dim": 256,
        "weights": "outdoor",
        "keypoint_encoder": [32, 64, 128, 256],
        "GNN_layers": ["self", "cross"] * 9,
        "sinkhorn_iterations": 100,
        "match_threshold": 0.2,
    }

    def __init__(self, config):
        """Build the underlying SuperGlue network from ``config``."""
        super().__init__(config)

    def forward(
        self,
        image0_shape,
        scores0,
        keypoints0,
        descriptors0,
        image1_shape,
        scores1,
        keypoints1,
        descriptors1,
    ):
        """Run SuperGlue on a pair of keypoints and descriptors.

        Args:
            image0_shape: Shape of the first image, forwarded to
                ``normalize_keypoints``.
            scores0: Detection scores of the first image's keypoints.
            keypoints0: Keypoint coordinates of the first image.
            descriptors0: Descriptors of the first image; after the final
                projection they are contracted as ``(batch, dim, n)``.
            image1_shape: Shape of the second image.
            scores1: Detection scores of the second image's keypoints.
            keypoints1: Keypoint coordinates of the second image.
            descriptors1: Descriptors of the second image.

        Returns:
            A dict with ``matches0`` / ``matches1`` (index of the matched
            keypoint in the other image, ``-1`` when there is no valid match)
            and ``matching_scores0`` / ``matching_scores1`` (the match
            confidence, ``0`` for keypoints without a mutual match).

        NOTE(review): unlike upstream ``SuperGlue.forward`` there is no early
        return for an empty keypoint set -- presumably intentional so the
        graph stays static for tracing/export; confirm with the author.
        """
        # Keypoint normalization.
        kpts0 = normalize_keypoints(keypoints0, image0_shape)
        kpts1 = normalize_keypoints(keypoints1, image1_shape)

        # Keypoint MLP encoder: add the positional encoding to the descriptors.
        desc0 = descriptors0 + self.kenc(kpts0, scores0)
        desc1 = descriptors1 + self.kenc(kpts1, scores1)

        # Multi-layer Transformer network.
        desc0, desc1 = self.gnn(desc0, desc1)

        # Final MLP projection.
        mdesc0 = self.final_proj(desc0)
        mdesc1 = self.final_proj(desc1)

        # Pairwise descriptor similarity, scaled by sqrt(descriptor_dim).
        scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
        scores = scores / self.config["descriptor_dim"] ** 0.5

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score, iters=self.config["sinkhorn_iterations"]
        )

        # Drop the last row and column before looking for the best partner of
        # each keypoint (presumably the dustbin entries appended by
        # ``log_optimal_transport`` -- confirm against upstream).
        inner = scores[:, :-1, :-1]
        max0, max1 = inner.max(2), inner.max(1)
        indices0, indices1 = max0.indices, max1.indices

        # A match is mutual when each keypoint is the other's best partner.
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)

        # Scores are in log space; exponentiate to get match confidences.
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)

        # Keep only mutual matches with score above "match_threshold".
        valid0 = mutual0 & (mscores0 > self.config["match_threshold"])
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

        return {
            "matches0": indices0,  # use -1 for invalid match
            "matches1": indices1,  # use -1 for invalid match
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
        }
0 commit comments