Skip to content

Commit 89cf774

Browse files
committed
convert superglue to onnx format
1 parent f536f54 commit 89cf774

File tree

5 files changed

+182
-0
lines changed

5 files changed

+182
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[submodule "scripts/superpoint/SuperPointPretrainedNetwork"]
22
path = scripts/superpoint/SuperPointPretrainedNetwork
33
url = https://github.com/magicleap/SuperPointPretrainedNetwork
4+
[submodule "scripts/superglue/SuperGluePretrainedNetwork"]
5+
path = scripts/superglue/SuperGluePretrainedNetwork
6+
url = https://github.com/magicleap/SuperGluePretrainedNetwork.git

scripts/superglue/README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# convert pre-trained superglue pytorch weights to onnx format
2+
3+
---
4+
5+
## dependencies
6+
7+
---
8+
9+
- python: 3x
10+
11+
-
12+
13+
```bash
14+
git submodule update --init --recursive
15+
16+
python3 -m pip install -r SuperGluePretrainedNetwork/requirements.txt
17+
```
18+
19+
## :running: how to run
20+
21+
---
22+
23+
24+
- export onnx weights
25+
26+
```
27+
python3 convert_to_onnx.py
28+
```
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
import os
3+
4+
import torch
5+
6+
from superglue_wrapper import SuperGlueWrapper as SuperGlue
7+
8+
9+
def main():
10+
config = {
11+
"descriptor_dim": 256,
12+
"weights": "indoor",
13+
"keypoint_encoder": [32, 64, 128, 256],
14+
"GNN_layers": ["self", "cross"] * 9,
15+
"sinkhorn_iterations": 100,
16+
"match_threshold": 0.2,
17+
}
18+
19+
model = SuperGlue(config=config)
20+
model.eval()
21+
22+
batch_size = 1
23+
height = 480
24+
width = 640
25+
num_keypoints = 382
26+
data = {}
27+
for i in range(2):
28+
data[f"image{i}_shape"] = torch.tensor([batch_size, 1, height, width])
29+
data[f"scores{i}"] = torch.randn(batch_size, num_keypoints)
30+
data[f"keypoints{i}"] = torch.randn(batch_size, num_keypoints, 2)
31+
data[f"descriptors{i}"] = torch.randn(batch_size, 256, num_keypoints)
32+
33+
torch.onnx.export(
34+
model,
35+
data,
36+
"super_glue.onnx",
37+
export_params=True,
38+
opset_version=12,
39+
do_constant_folding=True,
40+
input_names=list(data.keys()),
41+
output_names=["matches0", "matches1", "matching_scores0", "matching_scores1"],
42+
dynamic_axes={
43+
"keypoints0": {0: "batch_size", 1: "num_keypoints0"},
44+
"scores0": {0: "batch_size", 1: "num_keypoints0"},
45+
"descriptors0": {0: "batch_size", 2: "num_keypoints0"},
46+
"keypoints1": {0: "batch_size", 1: "num_keypoints1"},
47+
"scores1": {0: "batch_size", 1: "num_keypoints1"},
48+
"descriptors1": {0: "batch_size", 2: "num_keypoints1"},
49+
"matches0": {0: "batch_size", 1: "num_keypoints0"},
50+
"matches1": {0: "batch_size", 1: "num_keypoints1"},
51+
"matching_scores0": {0: "batch_size", 1: "num_keypoints0"},
52+
"matching_scores1": {0: "batch_size", 1: "num_keypoints1"},
53+
},
54+
)
55+
print(f"\nonnx model is saved to: {os.getcwd()}/super_glue.onnx")
56+
57+
58+
if __name__ == "__main__":
59+
main()
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import torch
2+
3+
from SuperGluePretrainedNetwork.models.superglue import (
4+
SuperGlue,
5+
arange_like,
6+
log_optimal_transport,
7+
normalize_keypoints,
8+
)
9+
10+
11+
class SuperGlueWrapper(SuperGlue):
12+
default_config = {
13+
"descriptor_dim": 256,
14+
"weights": "outdor",
15+
"keypoint_encoder": [32, 64, 128, 256],
16+
"GNN_layers": ["self", "cross"] * 9,
17+
"sinkhorn_iterations": 100,
18+
"match_threshold": 0.2,
19+
}
20+
21+
def __init__(self, config):
22+
SuperGlue.__init__(self, config)
23+
24+
def forward(
25+
self,
26+
image0_shape,
27+
scores0,
28+
keypoints0,
29+
descriptors0,
30+
image1_shape,
31+
scores1,
32+
keypoints1,
33+
descriptors1,
34+
):
35+
data = {
36+
"image0_shape": image0_shape,
37+
"scores0": scores0,
38+
"keypoints0": keypoints0,
39+
"descriptors0": descriptors0,
40+
"image1_shape": image1_shape,
41+
"scores1": scores1,
42+
"keypoints1": keypoints1,
43+
"descriptors1": descriptors1,
44+
}
45+
46+
"""Run SuperGlue on a pair of keypoints and descriptors"""
47+
desc0, desc1 = data["descriptors0"], data["descriptors1"]
48+
kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
49+
50+
# Keypoint normalization.
51+
kpts0 = normalize_keypoints(kpts0, data["image0_shape"])
52+
kpts1 = normalize_keypoints(kpts1, data["image1_shape"])
53+
54+
# Keypoint MLP encoder.
55+
desc0 = desc0 + self.kenc(kpts0, data["scores0"])
56+
desc1 = desc1 + self.kenc(kpts1, data["scores1"])
57+
58+
# Multi-layer Transformer network.
59+
desc0, desc1 = self.gnn(desc0, desc1)
60+
61+
# Final MLP projection.
62+
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
63+
64+
# Compute matching descriptor distance.
65+
scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
66+
scores = scores / self.config["descriptor_dim"] ** 0.5
67+
68+
# Run the optimal transport.
69+
scores = log_optimal_transport(
70+
scores, self.bin_score, iters=self.config["sinkhorn_iterations"]
71+
)
72+
73+
# Get the matches with score above "match_threshold".
74+
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
75+
indices0, indices1 = max0.indices, max1.indices
76+
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
77+
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
78+
zero = scores.new_tensor(0)
79+
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
80+
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
81+
valid0 = mutual0 & (mscores0 > self.config["match_threshold"])
82+
valid1 = mutual1 & valid0.gather(1, indices1)
83+
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
84+
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
85+
86+
return {
87+
"matches0": indices0, # use -1 for invalid match
88+
"matches1": indices1, # use -1 for invalid match
89+
"matching_scores0": mscores0,
90+
"matching_scores1": mscores1,
91+
}

0 commit comments

Comments
 (0)