diff --git a/.gitignore b/.gitignore index 68bc17f..ce39e3c 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +*.onnx +*.engine diff --git a/README.md b/README.md index 92c93f9..f571a8d 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,35 @@ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0] ``` +### TensorRT - Export +It is advisable to use an [NGC container](https://catalog.ngc.nvidia.com/containers). For example, for the NVIDIA Jetson platform, refer to [L4T ML](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml/tags). The additional dependencies you need: +- [TensorRT](https://github.com/NVIDIA/TensorRT) - usually available inside the docker container +- `onnx` +- `onnxruntime` +- [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) + +``` +python3 export.py --help + +usage: export.py [-h] [--weights WEIGHTS] [--imgsz IMGSZ IMGSZ] [--fp16_mode FP16_MODE] [--use_dynamic_axis USE_DYNAMIC_AXIS] + [--onnx_opset ONNX_OPSET] [--workspace WORKSPACE] + +Create ONNX and TensorRT export for XFeat. + +optional arguments: + -h, --help show this help message and exit + --weights WEIGHTS Path to the weights pt file to process + --imgsz IMGSZ IMGSZ Input image size + --fp16_mode FP16_MODE + --use_dynamic_axis USE_DYNAMIC_AXIS + --onnx_opset ONNX_OPSET + --workspace WORKSPACE +``` +### TensorRT - Demo +``` +python3 realtime_demo.py --method XFeat --use_engine True +``` + ### Training XFeat training code will be released soon. Please stay tuned. 
diff --git a/export.py b/export.py
new file mode 100644
index 0000000..78d9f25
--- /dev/null
+++ b/export.py
@@ -0,0 +1,139 @@
+import os
+import torch
+import tqdm
+import argparse
+import subprocess
+import numpy as np
+from modules.model import *
+
+
+def _str2bool(value: str) -> bool:
+    """Parse a textual command-line boolean such as ``True`` or ``False``.
+
+    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
+    non-empty value (including ``"False"``) would be parsed as ``True``.
+
+    Raises:
+        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
+    """
+    lowered = value.strip().lower()
+    if lowered in ("1", "true", "t", "yes", "y"):
+        return True
+    if lowered in ("0", "false", "f", "no", "n"):
+        return False
+    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")
+
+
+def build_onnx_engine(weights: str,
+                      onnx_weight: str,
+                      imgsz: tuple = (480, 640),
+                      use_dynamic_axis: bool = True,
+                      onnx_opset: int = 17) -> None:
+    """Export the XFeat network to an ONNX file.
+
+    Args:
+        weights: Path to the state dict loaded into ``XFeatModel``.
+        onnx_weight: Destination path of the ONNX file.
+        imgsz: Sizes of the two trailing dimensions of the ``1x3x...`` dummy input.
+        use_dynamic_axis: Export axis 0 of ``image`` as a dynamic ``batch`` axis.
+        onnx_opset: ONNX opset version passed to the exporter.
+
+    Raises:
+        ValueError: If ``onnx_weight`` is ``None``.
+    """
+    if onnx_weight is None:
+        raise ValueError("Onnx file path cannot be None.")
+
+    dev = 'cpu'
+    net = XFeatModel().to(dev).eval()
+    net.load_state_dict(torch.load(weights, map_location=dev))
+
+    # Random input; it is only the example tensor handed to the exporter.
+    x = torch.randn(1, 3, *imgsz).to(dev)
+
+    dynamic_axis = {"image": {0: "batch"}} if use_dynamic_axis else {}
+
+    torch.onnx.export(
+        net,
+        x,
+        onnx_weight,
+        input_names=XFeatModel.get_xfeat_input_names(),
+        output_names=XFeatModel.get_xfeat_output_names(),
+        dynamic_axes=dynamic_axis,
+        opset_version=onnx_opset,
+    )
+
+
+def build_tensorrt_engine(weights: str,
+                          imgsz: tuple = (480, 640),
+                          fp16_mode: bool = True,
+                          use_dynamic_axis: bool = True,
+                          onnx_opset: int = 17,
+                          workspace: int = 4096) -> None:
+    """Export ``weights`` to ONNX, then build a TensorRT engine with ``trtexec``.
+
+    The ONNX and engine files are written next to ``weights``, with the ``.pt``
+    suffix replaced by ``.onnx`` and ``.engine`` respectively.
+
+    Args:
+        weights: Path to the ``.pt`` state dict.
+        imgsz: Sizes of the two trailing dimensions of the ``1x3x...`` input shape.
+        fp16_mode: Pass ``--fp16`` to ``trtexec``.
+        use_dynamic_axis: Forwarded to :func:`build_onnx_engine`.
+        onnx_opset: Forwarded to :func:`build_onnx_engine`.
+        workspace: Value of the ``--workspace`` option of ``trtexec``.
+
+    Raises:
+        ValueError: If ``weights`` does not end with ``.pt``.
+        FileNotFoundError: If the ONNX file is missing after the export.
+        subprocess.CalledProcessError: If ``trtexec`` exits with a non-zero status.
+    """
+    if not weights.endswith(".pt"):
+        raise ValueError("File path does not end with '.pt'.")
+    stem = weights[:-len(".pt")]
+    onnx_weight = stem + ".onnx"
+    engine_weight = stem + ".engine"
+
+    build_onnx_engine(weights, onnx_weight, imgsz, use_dynamic_axis, onnx_opset)
+
+    if not os.path.exists(onnx_weight):
+        raise FileNotFoundError("ONNX export does not exist")
+
+    args = [
+        "/usr/src/tensorrt/bin/trtexec",
+        f"--onnx={onnx_weight}",
+        f"--saveEngine={engine_weight}",
+        f"--workspace={workspace}",
+    ]
+    if fp16_mode:
+        args.append("--fp16")
+    args.append(f"--shapes=image:1x3x{imgsz[0]}x{imgsz[1]}")
+
+    # check=True: a failed trtexec run must not be reported as a finished export.
+    subprocess.run(args, check=True)
+    print(f"Finished TensorRT engine export to {engine_weight}.")
+
+
+def main():
+    """Parse the command line and run the ONNX and TensorRT export."""
+    parser = argparse.ArgumentParser(description='Create ONNX and TensorRT export for XFeat.')
+    parser.add_argument('--weights', type=str, default=f'{os.path.abspath(os.path.dirname(__file__))}/weights/xfeat.pt', help='Path to the weights pt file to process')
+    parser.add_argument('--imgsz', nargs=2, type=int, default=[480,640], help='Input image size')
+    # type=bool would turn "--fp16_mode False" into True (bool("False") is True).
+    parser.add_argument("--fp16_mode", type=_str2bool, default=True)
+    parser.add_argument("--use_dynamic_axis", type=_str2bool, default=True)
+    parser.add_argument("--onnx_opset", type=int, default=17)
+    parser.add_argument("--workspace", type=int, default=4096)
+    args = parser.parse_args()
+    build_tensorrt_engine(
+        weights=args.weights,
+        imgsz=tuple(args.imgsz),
+        fp16_mode=args.fp16_mode,
+        use_dynamic_axis=args.use_dynamic_axis,
+        onnx_opset=args.onnx_opset,
+        workspace=args.workspace,
+    )
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/model.py b/modules/model.py index 57539fd..45eb49d 100644 --- a/modules/model.py +++ b/modules/model.py @@ -9,6 +9,8 @@ import torch.nn.functional as F import time +from typing import List + class BasicLayer(nn.Module): """ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU @@ -152,3 +154,12 @@ def forward(self, x): keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits return feats, keypoints, heatmap + + @staticmethod + def get_xfeat_input_names() -> List[str]: + return ["image"] + + @staticmethod + def get_xfeat_output_names() -> List[str]: + return ["feats", "keypoints", "heatmap"] + diff --git a/modules/xfeat.py b/modules/xfeat.py index f60a63d..eb200c5 100644 --- a/modules/xfeat.py +++ b/modules/xfeat.py @@ -20,18 +20,32 @@ class XFeat(nn.Module): It supports inference for both sparse and semi-dense feature extraction & 
matching. """ - def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096): + def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', + top_k = 4096, + use_engine=False, + use_fp16=False): super().__init__() self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.net = XFeatModel().to(self.dev).eval() self.top_k = top_k - - if weights is not None: - if isinstance(weights, str): - print('loading weights from: ' + weights) - self.net.load_state_dict(torch.load(weights, map_location=self.dev)) + self.use_engine = use_engine + self.use_fp16 = use_fp16 + if self.use_engine: + if os.path.exists(os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'): + weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine' else: - self.net.load_state_dict(weights) + raise Exception('Engine file does not exist.') + self.net = XFeat.load_xfeat_engine(weights) + self.dev = 'cuda' # force cuda for TensorRT + if self.use_fp16: + self.net.half() + else: + self.net = XFeatModel().to(self.dev).eval() + if weights is not None: + if isinstance(weights, str): + print('loading weights from: ' + weights) + self.net.load_state_dict(torch.load(weights, map_location=self.dev)) + else: + self.net.load_state_dict(weights) self.interpolator = InterpolateSparse2d('bicubic') @@ -55,12 +69,12 @@ def detectAndCompute(self, x, top_k = None): B, _, _H1, _W1 = x.shape M1, K1, H1 = self.net(x) + M1 = F.normalize(M1, dim=1) #Convert logits to heatmap and extract kpts K1h = self.get_kpts_heatmap(K1) mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5) - #Compute reliability scores _nearest = InterpolateSparse2d('nearest') _bilinear = InterpolateSparse2d('bilinear') @@ -174,6 +188,8 @@ def preprocess_tensor(self, x): if isinstance(x, np.ndarray) and x.shape == 3: x = torch.tensor(x).permute(2,0,1)[None] x = x.to(self.dev).float() + if self.use_fp16: + 
x.half() H, W = x.shape[-2:] _H, _W = (H//32) * 32, (W//32) * 32 @@ -272,7 +288,6 @@ def match(self, feats1, feats2, min_cossim = 0.82): cossim = feats1 @ feats2.t() cossim_t = feats2 @ feats1.t() - _, match12 = cossim.max(dim=1) _, match21 = cossim_t.max(dim=1) @@ -344,3 +359,26 @@ def parse_input(self, x): x = torch.tensor(x).permute(0,3,1,2)/255 return x + + @staticmethod + def load_xfeat_engine(engine_path: str): + if not engine_path.endswith(".engine"): + raise Exception('Invalid Engine file.') + + import tensorrt as trt + from torch2trt import TRTModule + trt.init_libnvinfer_plugins(None,'') + + with trt.Logger() as logger, trt.Runtime(logger) as runtime: + with open(engine_path, 'rb') as f: + engine_bytes = f.read() + engine = runtime.deserialize_cuda_engine(engine_bytes) + + xfeat_trt = TRTModule( + engine, + input_names=XFeatModel.get_xfeat_input_names(), + output_names=XFeatModel.get_xfeat_output_names(), + ) + + return xfeat_trt + diff --git a/realtime_demo.py b/realtime_demo.py index 42f7ec5..9ee7052 100644 --- a/realtime_demo.py +++ b/realtime_demo.py @@ -22,6 +22,8 @@ def argparser(): parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.') parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.') parser.add_argument('--cam', type=int, default=0, help='Webcam device number.') + parser.add_argument('--use_engine', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=False, help='Use generated TensorRT engine file.') + parser.add_argument('--use_fp16', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=False, help='Use generated TensorRT engine file with fp16 precision.') return parser.parse_args() @@ -59,13 +61,13 @@ def __init__(self, descriptor, matcher): self.descriptor = descriptor self.matcher = matcher -def init_method(method, max_kpts): +def init_method(method, max_kpts, use_engine=False, use_fp16=False): if method == "ORB": return Method(descriptor=cv2.ORB_create(max_kpts, 
fastThreshold=10), matcher=cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)) elif method == "SIFT": return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)) elif method == "XFeat": - return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat()) + return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=use_engine, use_fp16=use_fp16)), matcher=XFeat(use_engine=use_engine, use_fp16=use_fp16)) else: raise RuntimeError("Invalid Method.") @@ -97,7 +99,10 @@ def __init__(self, args): self.max_cnt = 30 #avg FPS over this number of frames #Set local feature method here -- we expect cv2 or Kornia convention - self.method = init_method(args.method, max_kpts=args.max_kpts) + self.method = init_method(args.method, + max_kpts=args.max_kpts, + use_engine=args.use_engine, + use_fp16=args.use_fp16) # Setting up font for captions self.font = cv2.FONT_HERSHEY_SIMPLEX @@ -129,9 +134,9 @@ def setup_camera(self): def draw_quad(self, frame, point_list): if len(self.corners) > 1: for i in range(len(self.corners) - 1): - cv2.line(frame, point_list[i], point_list[i + 1], self.line_color, self.line_thickness, lineType = self.line_type) + cv2.line(frame, tuple(point_list[i]), tuple(point_list[i + 1]), self.line_color, self.line_thickness, lineType = self.line_type) if len(self.corners) == 4: # Close the quadrilateral if 4 corners are defined - cv2.line(frame, point_list[3], point_list[0], self.line_color, self.line_thickness, lineType = self.line_type) + cv2.line(frame, tuple(point_list[3]), tuple(point_list[0]), self.line_color, self.line_thickness, lineType = self.line_type) def mouse_callback(self, event, x, y, flags, param): if event == cv2.EVENT_LBUTTONDOWN: