Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*.onnx
*.engine
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,35 @@ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True
output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0]
```

### TensorRT - Export
It's advisable to use an [NGC container](https://catalog.ngc.nvidia.com/containers). For example, for the NVIDIA Jetson platform refer to [L4T ML](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml/tags). The additional dependencies you need:
- [TensorRT](https://github.com/NVIDIA/TensorRT) - usually available inside the docker container
- `onnx`
- `onnxruntime`
- [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)

```
python3 export.py --help

usage: export.py [-h] [--weights WEIGHTS] [--imgsz IMGSZ IMGSZ] [--fp16_mode FP16_MODE] [--use_dynamic_axis USE_DYNAMIC_AXIS]
[--onnx_opset ONNX_OPSET] [--workspace WORKSPACE]

Create ONNX and TensorRT export for XFeat.

optional arguments:
-h, --help show this help message and exit
--weights WEIGHTS Path to the weights pt file to process
--imgsz IMGSZ IMGSZ Input image size
--fp16_mode FP16_MODE
--use_dynamic_axis USE_DYNAMIC_AXIS
--onnx_opset ONNX_OPSET
--workspace WORKSPACE
```
### TensorRT - Demo
```
python3 realtime_demo.py --method XFeat --use_engine True
```

### Training
XFeat training code will be released soon. Please stay tuned.

Expand Down
101 changes: 101 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
import torch
import tqdm
import argparse
import subprocess
import numpy as np
from modules.model import *

def build_onnx_engine(weights: str,
                      onnx_weight: str,
                      imgsz: tuple = (480,640),
                      use_dynamic_axis: bool = True,
                      onnx_opset: int = 17) -> None:
    """Export the XFeat network to an ONNX file.

    The model is instantiated on the CPU, the checkpoint at ``weights`` is
    loaded into it, and it is exported with a random ``(1, 3, H, W)`` input
    where ``(H, W) = imgsz``.

    Args:
        weights: Path to the PyTorch checkpoint passed to ``torch.load``.
        onnx_weight: Destination path of the ONNX file.
        imgsz: ``(height, width)`` of the dummy input used for the export.
        use_dynamic_axis: When true, axis 0 of every model input is exported
            as a dynamic axis named ``"batch"``.
        onnx_opset: ONNX opset version handed to ``torch.onnx.export``.

    Raises:
        Exception: If ``onnx_weight`` is ``None``.
    """
    # Validate before any model construction or checkpoint loading happens.
    if onnx_weight is None:
        raise Exception("Onnx file path cannot be None.")

    dev = 'cpu'
    net = XFeatModel().to(dev).eval()
    net.load_state_dict(torch.load(weights, map_location=dev))

    # Random input, only used to run the model once during export.
    x = torch.randn(1, 3, *imgsz).to(dev)

    input_names = XFeatModel.get_xfeat_input_names()
    # Key the dynamic axes by the same names given to the exporter so the two
    # cannot drift apart.
    if use_dynamic_axis:
        dynamic_axis = {name: {0: "batch"} for name in input_names}
    else:
        dynamic_axis = {}

    torch.onnx.export(
        net,
        x,
        onnx_weight,
        input_names=input_names,
        output_names=XFeatModel.get_xfeat_output_names(),
        dynamic_axes=dynamic_axis,
        opset_version=onnx_opset,
    )

def build_tensorrt_engine(weights: str,
                          imgsz: tuple = (480,640),
                          fp16_mode: bool = True,
                          use_dynamic_axis: bool = True,
                          onnx_opset: int = 17,
                          workspace: int = 4096) -> None:
    """Export XFeat to ONNX and build a TensorRT engine from it with trtexec.

    The ONNX file and the engine are written next to ``weights``, with the
    ``.pt`` suffix replaced by ``.onnx`` and ``.engine`` respectively.

    Args:
        weights: Path to the PyTorch checkpoint; must end with ``.pt``.
        imgsz: ``(height, width)`` used for the ONNX export and for the
            ``--shapes`` argument given to trtexec.
        fp16_mode: When true, ``--fp16`` is passed to trtexec.
        use_dynamic_axis: Forwarded to ``build_onnx_engine``.
        onnx_opset: Forwarded to ``build_onnx_engine``.
        workspace: Value passed to trtexec as ``--workspace``.

    Raises:
        Exception: If ``weights`` does not end with ``.pt`` or the ONNX file
            is missing after the export step.
        RuntimeError: If trtexec exits with a non-zero status.
    """
    if not weights.endswith(".pt"):
        raise Exception("File path does not end with '.pt'.")

    # Both artifacts share the checkpoint's path without its ".pt" suffix.
    stem = weights[:-len(".pt")]
    onnx_weight = stem + ".onnx"
    engine_weight = stem + ".engine"

    build_onnx_engine(weights, onnx_weight, imgsz, use_dynamic_axis, onnx_opset)

    if not os.path.exists(onnx_weight):
        raise Exception("ONNX export does not exist")

    # NOTE(review): the trtexec location is hard-coded and `--workspace` may
    # not be accepted by every TensorRT release -- confirm against the
    # targeted TensorRT version.
    args = [
        "/usr/src/tensorrt/bin/trtexec",
        f"--onnx={onnx_weight}",
        f"--saveEngine={engine_weight}",
        f"--workspace={workspace}",
    ]
    if fp16_mode:
        args.append("--fp16")
    args.append(f"--shapes=image:1x3x{imgsz[0]}x{imgsz[1]}")

    # Do not report success when trtexec failed: surface its exit status.
    returncode = subprocess.call(args)
    if returncode != 0:
        raise RuntimeError(
            f"trtexec exited with code {returncode}; "
            f"TensorRT engine was not exported to {engine_weight}."
        )
    print(f"Finished TensorRT engine export to {engine_weight}.")

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main():
    """Parse the command line and run the ONNX + TensorRT export.

    Boolean options take an explicit value (e.g. ``--fp16_mode False``).
    They are parsed with ``_str2bool`` because ``type=bool`` would treat any
    non-empty string, including ``"False"``, as true.
    """
    parser = argparse.ArgumentParser(description='Create ONNX and TensorRT export for XFeat.')
    parser.add_argument('--weights', type=str, default=f'{os.path.abspath(os.path.dirname(__file__))}/weights/xfeat.pt', help='Path to the weights pt file to process')
    parser.add_argument('--imgsz', nargs=2, type=int, default=[480,640], help='Input image size')
    parser.add_argument("--fp16_mode", type=_str2bool, default=True)
    parser.add_argument("--use_dynamic_axis", type=_str2bool, default=True)
    parser.add_argument("--onnx_opset", type=int, default=17)
    parser.add_argument("--workspace", type=int, default=4096)
    args = parser.parse_args()

    build_tensorrt_engine(
        args.weights,
        args.imgsz,
        args.fp16_mode,
        args.use_dynamic_axis,
        args.onnx_opset,
        args.workspace,
    )


if __name__ == '__main__':
    main()






11 changes: 11 additions & 0 deletions modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch.nn.functional as F
import time

from typing import List

class BasicLayer(nn.Module):
"""
Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
Expand Down Expand Up @@ -152,3 +154,12 @@ def forward(self, x):
keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits

return feats, keypoints, heatmap

@staticmethod
def get_xfeat_input_names() -> List[str]:
    """Return the names given to the network's input tensors on export."""
    input_names: List[str] = ["image"]
    return input_names

@staticmethod
def get_xfeat_output_names() -> List[str]:
    """Return the names given to the network's output tensors on export."""
    output_names: List[str] = ["feats", "keypoints", "heatmap"]
    return output_names

58 changes: 48 additions & 10 deletions modules/xfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,32 @@ class XFeat(nn.Module):
It supports inference for both sparse and semi-dense feature extraction & matching.
"""

def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096):
def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt',
top_k = 4096,
use_engine=False,
use_fp16=False):
super().__init__()
self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.net = XFeatModel().to(self.dev).eval()
self.top_k = top_k

if weights is not None:
if isinstance(weights, str):
print('loading weights from: ' + weights)
self.net.load_state_dict(torch.load(weights, map_location=self.dev))
self.use_engine = use_engine
self.use_fp16 = use_fp16
if self.use_engine:
if os.path.exists(os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'):
weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'
else:
self.net.load_state_dict(weights)
raise Exception('Engine file does not exist.')
self.net = XFeat.load_xfeat_engine(weights)
self.dev = 'cuda' # force cuda for TensorRT
if self.use_fp16:
self.net.half()
else:
self.net = XFeatModel().to(self.dev).eval()
if weights is not None:
if isinstance(weights, str):
print('loading weights from: ' + weights)
self.net.load_state_dict(torch.load(weights, map_location=self.dev))
else:
self.net.load_state_dict(weights)

self.interpolator = InterpolateSparse2d('bicubic')

Expand All @@ -55,12 +69,12 @@ def detectAndCompute(self, x, top_k = None):
B, _, _H1, _W1 = x.shape

M1, K1, H1 = self.net(x)

M1 = F.normalize(M1, dim=1)

#Convert logits to heatmap and extract kpts
K1h = self.get_kpts_heatmap(K1)
mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5)

#Compute reliability scores
_nearest = InterpolateSparse2d('nearest')
_bilinear = InterpolateSparse2d('bilinear')
Expand Down Expand Up @@ -174,6 +188,8 @@ def preprocess_tensor(self, x):
if isinstance(x, np.ndarray) and x.shape == 3:
x = torch.tensor(x).permute(2,0,1)[None]
x = x.to(self.dev).float()
if self.use_fp16:
x.half()

H, W = x.shape[-2:]
_H, _W = (H//32) * 32, (W//32) * 32
Expand Down Expand Up @@ -272,7 +288,6 @@ def match(self, feats1, feats2, min_cossim = 0.82):

cossim = feats1 @ feats2.t()
cossim_t = feats2 @ feats1.t()

_, match12 = cossim.max(dim=1)
_, match21 = cossim_t.max(dim=1)

Expand Down Expand Up @@ -344,3 +359,26 @@ def parse_input(self, x):
x = torch.tensor(x).permute(0,3,1,2)/255

return x

@staticmethod
def load_xfeat_engine(engine_path: str):
    """Load a serialized TensorRT engine and wrap it in a torch2trt ``TRTModule``.

    Args:
        engine_path: Path to the serialized engine; must end with ``.engine``.

    Returns:
        A ``TRTModule`` built from the deserialized engine, using the input
        and output names reported by ``XFeatModel``.

    Raises:
        Exception: If ``engine_path`` does not end with ``.engine`` or the
            engine cannot be deserialized.
    """
    if not engine_path.endswith(".engine"):
        raise Exception('Invalid Engine file.')

    # Imported inside the function, so these packages are only needed when an
    # engine is actually loaded.
    import tensorrt as trt
    from torch2trt import TRTModule
    trt.init_libnvinfer_plugins(None,'')

    with open(engine_path, 'rb') as f:
        engine_bytes = f.read()

    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(engine_bytes)

    # Fail here with a clear message instead of handing None to TRTModule.
    if engine is None:
        raise Exception(f'Failed to deserialize TensorRT engine: {engine_path}')

    xfeat_trt = TRTModule(
        engine,
        input_names=XFeatModel.get_xfeat_input_names(),
        output_names=XFeatModel.get_xfeat_output_names(),
    )

    return xfeat_trt

15 changes: 10 additions & 5 deletions realtime_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def argparser():
parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.')
parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.')
parser.add_argument('--cam', type=int, default=0, help='Webcam device number.')
parser.add_argument('--use_engine', type=bool, default=False, help='Use generated TensorRT engine file.')
parser.add_argument('--use_fp16', type=bool, default=False, help='Use generated TensorRT engine file with fp16 precision.')
return parser.parse_args()


Expand Down Expand Up @@ -59,13 +61,13 @@ def __init__(self, descriptor, matcher):
self.descriptor = descriptor
self.matcher = matcher

def init_method(method, max_kpts):
def init_method(method, max_kpts, use_engine=False, use_fp16=False):
if method == "ORB":
return Method(descriptor=cv2.ORB_create(max_kpts, fastThreshold=10), matcher=cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True))
elif method == "SIFT":
return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True))
elif method == "XFeat":
return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat())
return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=use_engine, use_fp16=use_fp16)), matcher=XFeat(use_engine=use_engine, use_fp16=use_fp16))
else:
raise RuntimeError("Invalid Method.")

Expand Down Expand Up @@ -97,7 +99,10 @@ def __init__(self, args):
self.max_cnt = 30 #avg FPS over this number of frames

#Set local feature method here -- we expect cv2 or Kornia convention
self.method = init_method(args.method, max_kpts=args.max_kpts)
self.method = init_method(args.method,
max_kpts=args.max_kpts,
use_engine=args.use_engine,
use_fp16=args.use_fp16)

# Setting up font for captions
self.font = cv2.FONT_HERSHEY_SIMPLEX
Expand Down Expand Up @@ -129,9 +134,9 @@ def setup_camera(self):
def draw_quad(self, frame, point_list):
if len(self.corners) > 1:
for i in range(len(self.corners) - 1):
cv2.line(frame, point_list[i], point_list[i + 1], self.line_color, self.line_thickness, lineType = self.line_type)
cv2.line(frame, tuple(point_list[i]), tuple(point_list[i + 1]), self.line_color, self.line_thickness, lineType = self.line_type)
if len(self.corners) == 4: # Close the quadrilateral if 4 corners are defined
cv2.line(frame, point_list[3], point_list[0], self.line_color, self.line_thickness, lineType = self.line_type)
cv2.line(frame, tuple(point_list[3]), tuple(point_list[0]), self.line_color, self.line_thickness, lineType = self.line_type)

def mouse_callback(self, event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
Expand Down