#!/usr/bin/env python
import copy
import os
import sys
from typing import Any, Dict

import torch
from einops.einops import rearrange

# Make the vendored LoFTR repository importable (it is expected to live next
# to this file, e.g. as a git submodule).
_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(_CURRENT_DIR, "LoFTR"))

from src.loftr import LoFTR, default_cfg

# Enable the upstream `temp_bug_fix` option, which corrects the temperature
# term in the sinusoidal positional encoding and matches the newer
# pretrained checkpoints.
DEFAULT_CFG = copy.deepcopy(default_cfg)
DEFAULT_CFG["coarse"]["temp_bug_fix"] = True


class LoFTRWrapper(LoFTR):
    """LoFTR whose forward takes plain tensors and returns a dict of tensors,
    which keeps the module convenient to trace or export."""

    def __init__(
        self,
        config: Dict[str, Any] = DEFAULT_CFG,
    ):
        super().__init__(config)

    def forward(
        self,
        image0: torch.Tensor,
        image1: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        # Pack the inputs into the shared `data` dict that LoFTR's submodules
        # read from and mutate in place.
        data = {
            "image0": image0,
            "image1": image1,
        }
        del image0, image1

        data.update(
            {
                "bs": data["image0"].size(0),
                "hw0_i": data["image0"].shape[2:],
                "hw1_i": data["image1"].shape[2:],
            }
        )

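        # 1. local feature CNN: extract coarse and fine feature maps (with the
        # default ResNetFPN_8_2 backbone these are at 1/8 and 1/2 resolution).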
        if data["hw0_i"] == data["hw1_i"]:  # faster & better BN convergence
            feats_c, feats_f = self.backbone(
                torch.cat([data["image0"], data["image1"]], dim=0)
            )
            (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
                data["bs"]
            ), feats_f.split(data["bs"])
        else:  # handle different input shapes
            (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
                data["image0"]
            ), self.backbone(data["image1"])

        data.update(
            {
                "hw0_c": feat_c0.shape[2:],
                "hw1_c": feat_c1.shape[2:],
                "hw0_f": feat_f0.shape[2:],
                "hw1_f": feat_f1.shape[2:],
            }
        )

        # 2. coarse-level LoFTR module
        # add positional encoding to the feature maps, then flatten them to
        # sequences of shape [N, HW, C]
        feat_c0 = rearrange(self.pos_encoding(feat_c0), "n c h w -> n (h w) c")
        feat_c1 = rearrange(self.pos_encoding(feat_c1), "n c h w -> n (h w) c")

        mask_c0 = mask_c1 = None  # padding masks are only provided during training
        if "mask0" in data:
            mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)
        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)

        # 3. match coarse-level
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
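        # coarse_matching stores the coarse matches (e.g. mkpts0_c / mkpts1_c
        # and their confidences, mconf) back into `data`.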

        # 4. fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
            feat_f0, feat_f1, feat_c0, feat_c1, data
        )
        if feat_f0_unfold.size(0) != 0:  # at least one coarse-level match was predicted
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
                feat_f0_unfold, feat_f1_unfold
            )

        # 5. match fine-level
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
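        # fine_matching refines the coarse matches to sub-pixel accuracy and
        # stores the final matches (mkpts0_f / mkpts1_f) in `data`.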

        # Rename LoFTR's internal output keys to a stable public interface.
        rename_keys: Dict[str, str] = {
            "mkpts0_f": "keypoints0",
            "mkpts1_f": "keypoints1",
            "mconf": "confidence",
        }
        out: Dict[str, torch.Tensor] = {}
        for k, v in rename_keys.items():
            _d = data[k]
            if isinstance(_d, torch.Tensor):
                out[v] = _d
            else:
                raise TypeError(
                    f"Expected torch.Tensor for item `{k}`, got {type(_d)}"
                )
        del data

        return out
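

if __name__ == "__main__":
    # Minimal smoke-test sketch. The checkpoint path below is an assumption
    # (the official LoFTR weights, e.g. outdoor_ds.ckpt, are distributed as
    # Lightning checkpoints with a "state_dict" key); without the file the
    # model simply runs with randomly initialized weights.
    model = LoFTRWrapper().eval()
    ckpt_path = os.path.join(_CURRENT_DIR, "weights", "outdoor_ds.ckpt")  # hypothetical path
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path)["state_dict"])

    # LoFTR expects grayscale images in [0, 1] with H and W divisible by 8.
    image0 = torch.rand(1, 1, 480, 640)
    image1 = torch.rand(1, 1, 480, 640)
    with torch.no_grad():
        out = model(image0, image1)
    print({k: tuple(v.shape) for k, v in out.items()})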