|
| 1 | +import os |
| 2 | +import sys |
| 3 | +from omegaconf import DictConfig, OmegaConf |
| 4 | +import logging |
| 5 | +import torch |
| 6 | +from torch_geometric.data import Batch |
| 7 | + |
| 8 | +from torch_points3d.applications.modelfactory import ModelFactory |
| 9 | +import torch_points3d.modules.SparseConv3d as sp3d |
| 10 | +from torch_points3d.modules.SparseConv3d.modules import * |
| 11 | + |
| 12 | +# from torch_points3d.core.base_conv.message_passing import * |
| 13 | +# from torch_points3d.core.base_conv.partial_dense import * |
| 14 | +from torch_points3d.models.base_architectures.unet import UnwrappedUnetBasedModel |
| 15 | +from torch_points3d.core.common_modules.base_modules import MLP |
| 16 | + |
| 17 | +from .utils import extract_output_nc |
| 18 | + |
| 19 | + |
| 20 | +CUR_FILE = os.path.realpath(__file__) |
| 21 | +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) |
| 22 | +PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/sparseconv3d") |
| 23 | + |
| 24 | +log = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +def SparseConv3d( |
| 28 | + architecture: str = None, |
| 29 | + input_nc: int = None, |
| 30 | + num_layers: int = None, |
| 31 | + config: DictConfig = None, |
| 32 | + backend: str = "minkowski", |
| 33 | + *args, |
| 34 | + **kwargs |
| 35 | +): |
| 36 | + """Create a Sparse Conv backbone model based on architecture proposed in |
| 37 | + https://arxiv.org/abs/1904.08755 |
| 38 | +
|
| 39 | + Two backends are available at the moment: |
| 40 | + - https://github.com/mit-han-lab/torchsparse |
| 41 | + - https://github.com/NVIDIA/MinkowskiEngine |
| 42 | +
|
| 43 | + Parameters |
| 44 | + ---------- |
| 45 | + architecture : str, optional |
| 46 | + Architecture of the model, choose from unet, encoder and decoder |
| 47 | + input_nc : int, optional |
| 48 | + Number of channels for the input |
| 49 | + output_nc : int, optional |
| 50 | + If specified, then we add a fully connected head at the end of the network to provide the requested dimension |
| 51 | + num_layers : int, optional |
| 52 | + Depth of the network |
| 53 | + config : DictConfig, optional |
| 54 | + Custom config, overrides the num_layers and architecture parameters |
| 55 | + block: |
| 56 | + Type of resnet block, ResBlock by default but can be any of the blocks in modules/SparseConv3d/modules.py |
| 57 | + backend: |
| 58 | + torchsparse or minkowski |
| 59 | + """ |
| 60 | + if "SPARSE_BACKEND" in os.environ and sp3d.nn.backend_valid(os.environ["SPARSE_BACKEND"]): |
| 61 | + sp3d.nn.set_backend(os.environ["SPARSE_BACKEND"]) |
| 62 | + else: |
| 63 | + sp3d.nn.set_backend(backend) |
| 64 | + |
| 65 | + factory = SparseConv3dFactory( |
| 66 | + architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config, **kwargs |
| 67 | + ) |
| 68 | + return factory.build() |
| 69 | + |
| 70 | + |
| 71 | +class SparseConv3dFactory(ModelFactory): |
| 72 | + def _build_unet(self): |
| 73 | + if self._config: |
| 74 | + model_config = self._config |
| 75 | + else: |
| 76 | + path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers)) |
| 77 | + model_config = OmegaConf.load(path_to_model) |
| 78 | + ModelFactory.resolve_model(model_config, self.num_features, self._kwargs) |
| 79 | + modules_lib = sys.modules[__name__] |
| 80 | + return SparseConv3dUnet(model_config, None, modules_lib, **self.kwargs) |
| 81 | + |
| 82 | + def _build_encoder(self): |
| 83 | + if self._config: |
| 84 | + model_config = self._config |
| 85 | + else: |
| 86 | + path_to_model = os.path.join( |
| 87 | + PATH_TO_CONFIG, |
| 88 | + "encoder_{}.yaml".format(self.num_layers), |
| 89 | + ) |
| 90 | + model_config = OmegaConf.load(path_to_model) |
| 91 | + ModelFactory.resolve_model(model_config, self.num_features, self._kwargs) |
| 92 | + modules_lib = sys.modules[__name__] |
| 93 | + return SparseConv3dEncoder(model_config, None, modules_lib, **self.kwargs) |
| 94 | + |
| 95 | + |
| 96 | +class BaseSparseConv3d(UnwrappedUnetBasedModel): |
| 97 | + CONV_TYPE = "sparse" |
| 98 | + |
| 99 | + def __init__(self, model_config, model_type, modules, *args, **kwargs): |
| 100 | + super().__init__(model_config, model_type, modules) |
| 101 | + self.weight_initialization() |
| 102 | + default_output_nc = kwargs.get("default_output_nc", None) |
| 103 | + if not default_output_nc: |
| 104 | + default_output_nc = extract_output_nc(model_config) |
| 105 | + |
| 106 | + self._output_nc = default_output_nc |
| 107 | + self._has_mlp_head = False |
| 108 | + if "output_nc" in kwargs: |
| 109 | + self._has_mlp_head = True |
| 110 | + self._output_nc = kwargs["output_nc"] |
| 111 | + self.mlp = MLP([default_output_nc, self.output_nc], activation=torch.nn.ReLU(), bias=False) |
| 112 | + |
| 113 | + @property |
| 114 | + def has_mlp_head(self): |
| 115 | + return self._has_mlp_head |
| 116 | + |
| 117 | + @property |
| 118 | + def output_nc(self): |
| 119 | + return self._output_nc |
| 120 | + |
| 121 | + def weight_initialization(self): |
| 122 | + for m in self.modules(): |
| 123 | + if isinstance(m, sp3d.nn.Conv3d) or isinstance(m, sp3d.nn.Conv3dTranspose): |
| 124 | + torch.nn.init.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") |
| 125 | + |
| 126 | + if isinstance(m, sp3d.nn.BatchNorm): |
| 127 | + torch.nn.init.constant_(m.bn.weight, 1) |
| 128 | + torch.nn.init.constant_(m.bn.bias, 0) |
| 129 | + |
| 130 | + def _set_input(self, data): |
| 131 | + """Unpack input data from the dataloader and perform necessary pre-processing steps. |
| 132 | +
|
| 133 | + Parameters |
| 134 | + ----------- |
| 135 | + data: |
| 136 | + a dictionary that contains the data itself and its metadata information. |
| 137 | + """ |
| 138 | + self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch) |
| 139 | + if data.pos is not None: |
| 140 | + self.xyz = data.pos |
| 141 | + else: |
| 142 | + self.xyz = data.coords |
| 143 | + |
| 144 | + |
| 145 | +class SparseConv3dEncoder(BaseSparseConv3d): |
| 146 | + def forward(self, data, *args, **kwargs): |
| 147 | + """ |
| 148 | + Parameters: |
| 149 | + ----------- |
| 150 | + data |
| 151 | + A SparseTensor that contains the data itself and its metadata information. Should contain |
| 152 | + F -- Features [N, C] |
| 153 | + coords -- Coords [N, 4] |
| 154 | +
|
| 155 | + Returns |
| 156 | + -------- |
| 157 | + data: |
| 158 | + - x [1, output_nc] |
| 159 | +
|
| 160 | + """ |
| 161 | + self._set_input(data) |
| 162 | + data = self.input |
| 163 | + for i in range(len(self.down_modules)): |
| 164 | + data = self.down_modules[i](data) |
| 165 | + |
| 166 | + out = Batch(x=data.F, batch=data.C[:, 0].long().to(data.F.device)) |
| 167 | + if not isinstance(self.inner_modules[0], Identity): |
| 168 | + out = self.inner_modules[0](out) |
| 169 | + |
| 170 | + if self.has_mlp_head: |
| 171 | + out.x = self.mlp(out.x) |
| 172 | + return out |
| 173 | + |
| 174 | + |
| 175 | +class SparseConv3dUnet(BaseSparseConv3d): |
| 176 | + def forward(self, data, *args, **kwargs): |
| 177 | + """Run forward pass. |
| 178 | + Input --- D1 -- D2 -- D3 -- U1 -- U2 -- output |
| 179 | + | |_________| | |
| 180 | + |______________________| |
| 181 | +
|
| 182 | + Parameters |
| 183 | + ----------- |
| 184 | + data |
| 185 | + A SparseTensor that contains the data itself and its metadata information. Should contain |
| 186 | + F -- Features [N, C] |
| 187 | + coords -- Coords [N, 4] |
| 188 | +
|
| 189 | + Returns |
| 190 | + -------- |
| 191 | + data: |
| 192 | + - pos [N, 3] (coords or real pos if xyz is in data) |
| 193 | + - x [N, output_nc] |
| 194 | + - batch [N] |
| 195 | + """ |
| 196 | + self._set_input(data) |
| 197 | + data = self.input |
| 198 | + stack_down = [] |
| 199 | + for i in range(len(self.down_modules) - 1): |
| 200 | + data = self.down_modules[i](data) |
| 201 | + stack_down.append(data) |
| 202 | + |
| 203 | + data = self.down_modules[-1](data) |
| 204 | + stack_down.append(None) |
| 205 | + # TODO : Manage the inner module |
| 206 | + for i in range(len(self.up_modules)): |
| 207 | + data = self.up_modules[i](data, stack_down.pop()) |
| 208 | + |
| 209 | + out = Batch(x=data.F, pos=self.xyz) |
| 210 | + if self.has_mlp_head: |
| 211 | + out.x = self.mlp(out.x) |
| 212 | + return out |
0 commit comments