Skip to content

Commit 1182cbb

Browse files
authored
Merge pull request #6 from torch-points3d/model
Model
2 parents 7fa0cb7 + d89235a commit 1182cbb

File tree

24 files changed

+1480
-5
lines changed

24 files changed

+1480
-5
lines changed

test/test_model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
import sys
3+
import os
4+
import torch
5+
from omegaconf import OmegaConf
6+
7+
from torch_geometric.data import Batch
8+
9+
DIR = os.path.dirname(os.path.realpath(__file__))
10+
ROOT = os.path.join(DIR, "..")
11+
sys.path.insert(0, ROOT)
12+
13+
from torch_points3d.models.segmentation.sparseconv3d import APIModel
14+
15+
16+
class TestAPIModel(unittest.TestCase):
17+
def test_forward(self):
18+
option_dataset = OmegaConf.create({"feature_dimension": 1, "num_classes": 10})
19+
20+
option = OmegaConf.load(os.path.join(ROOT, "conf", "models", "segmentation", "sparseconv3d.yaml"))
21+
name_model = list(option.keys())[0]
22+
model = APIModel(option[name_model], option_dataset)
23+
24+
pos = torch.randn(1000, 3)
25+
coords = torch.round(pos * 10000)
26+
x = torch.ones(1000, 1)
27+
batch = torch.zeros(1000).long()
28+
y = torch.randint(0, 10, (1000,))
29+
data = Batch(pos=pos, x=x, batch=batch, y=y, coords=coords)
30+
model.set_input(data)
31+
model.forward()
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()

torch_points3d/applications/__init__.py

Whitespace-only changes.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from enum import Enum
2+
from omegaconf import DictConfig
3+
import logging
4+
5+
from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve
6+
7+
log = logging.getLogger(__name__)
8+
9+
10+
class ModelArchitectures(Enum):
11+
UNET = "unet"
12+
ENCODER = "encoder"
13+
DECODER = "decoder"
14+
15+
16+
class ModelFactory:
17+
MODEL_ARCHITECTURES = [e.value for e in ModelArchitectures]
18+
19+
@staticmethod
20+
def raise_enum_error(arg_name, arg_value, options):
21+
raise Exception("The provided argument {} with value {} isn't within {}".format(arg_name, arg_value, options))
22+
23+
def __init__(
24+
self,
25+
architecture: str = None,
26+
input_nc: int = None,
27+
num_layers: int = None,
28+
config: DictConfig = None,
29+
**kwargs
30+
):
31+
if not architecture:
32+
raise ValueError()
33+
self._architecture = architecture.lower()
34+
assert self._architecture in self.MODEL_ARCHITECTURES, ModelFactory.raise_enum_error(
35+
"model_architecture", self._architecture, self.MODEL_ARCHITECTURES
36+
)
37+
38+
self._input_nc = input_nc
39+
self._num_layers = num_layers
40+
self._config = config
41+
self._kwargs = kwargs
42+
43+
if self._config:
44+
log.info("The config will be used to build the model")
45+
46+
@property
47+
def modules_lib(self):
48+
raise NotImplementedError
49+
50+
@property
51+
def kwargs(self):
52+
return self._kwargs
53+
54+
@property
55+
def num_layers(self):
56+
return self._num_layers
57+
58+
@property
59+
def num_features(self):
60+
return self._input_nc
61+
62+
def _build_unet(self):
63+
raise NotImplementedError
64+
65+
def _build_encoder(self):
66+
raise NotImplementedError
67+
68+
def _build_decoder(self):
69+
raise NotImplementedError
70+
71+
def build(self):
72+
if self._architecture == ModelArchitectures.UNET.value:
73+
return self._build_unet()
74+
elif self._architecture == ModelArchitectures.ENCODER.value:
75+
return self._build_encoder()
76+
elif self._architecture == ModelArchitectures.DECODER.value:
77+
return self._build_decoder()
78+
else:
79+
raise NotImplementedError
80+
81+
@staticmethod
82+
def resolve_model(model_config, num_features, kwargs):
83+
"""Parses the model config and evaluates any expression that may contain constants
84+
Overrides any argument in the `define_constants` with keywords wrgument to the constructor
85+
"""
86+
# placeholders to subsitute
87+
constants = {
88+
"FEAT": max(num_features, 0),
89+
}
90+
91+
# user defined contants to subsitute
92+
if "define_constants" in model_config.keys():
93+
constants.update(dict(model_config.define_constants))
94+
define_constants = model_config.define_constants
95+
for key in define_constants.keys():
96+
value = kwargs.get(key)
97+
if value:
98+
constants[key] = value
99+
resolve(model_config, constants)
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import os
2+
import sys
3+
from omegaconf import DictConfig, OmegaConf
4+
import logging
5+
import torch
6+
from torch_geometric.data import Batch
7+
8+
from torch_points3d.applications.modelfactory import ModelFactory
9+
import torch_points3d.modules.SparseConv3d as sp3d
10+
from torch_points3d.modules.SparseConv3d.modules import *
11+
12+
# from torch_points3d.core.base_conv.message_passing import *
13+
# from torch_points3d.core.base_conv.partial_dense import *
14+
from torch_points3d.models.base_architectures.unet import UnwrappedUnetBasedModel
15+
from torch_points3d.core.common_modules.base_modules import MLP
16+
17+
from .utils import extract_output_nc
18+
19+
20+
CUR_FILE = os.path.realpath(__file__)
21+
DIR_PATH = os.path.dirname(os.path.realpath(__file__))
22+
PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/sparseconv3d")
23+
24+
log = logging.getLogger(__name__)
25+
26+
27+
def SparseConv3d(
28+
architecture: str = None,
29+
input_nc: int = None,
30+
num_layers: int = None,
31+
config: DictConfig = None,
32+
backend: str = "minkowski",
33+
*args,
34+
**kwargs
35+
):
36+
"""Create a Sparse Conv backbone model based on architecture proposed in
37+
https://arxiv.org/abs/1904.08755
38+
39+
Two backends are available at the moment:
40+
- https://github.com/mit-han-lab/torchsparse
41+
- https://github.com/NVIDIA/MinkowskiEngine
42+
43+
Parameters
44+
----------
45+
architecture : str, optional
46+
Architecture of the model, choose from unet, encoder and decoder
47+
input_nc : int, optional
48+
Number of channels for the input
49+
output_nc : int, optional
50+
If specified, then we add a fully connected head at the end of the network to provide the requested dimension
51+
num_layers : int, optional
52+
Depth of the network
53+
config : DictConfig, optional
54+
Custom config, overrides the num_layers and architecture parameters
55+
block:
56+
Type of resnet block, ResBlock by default but can be any of the blocks in modules/SparseConv3d/modules.py
57+
backend:
58+
torchsparse or minkowski
59+
"""
60+
if "SPARSE_BACKEND" in os.environ and sp3d.nn.backend_valid(os.environ["SPARSE_BACKEND"]):
61+
sp3d.nn.set_backend(os.environ["SPARSE_BACKEND"])
62+
else:
63+
sp3d.nn.set_backend(backend)
64+
65+
factory = SparseConv3dFactory(
66+
architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config, **kwargs
67+
)
68+
return factory.build()
69+
70+
71+
class SparseConv3dFactory(ModelFactory):
72+
def _build_unet(self):
73+
if self._config:
74+
model_config = self._config
75+
else:
76+
path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers))
77+
model_config = OmegaConf.load(path_to_model)
78+
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
79+
modules_lib = sys.modules[__name__]
80+
return SparseConv3dUnet(model_config, None, modules_lib, **self.kwargs)
81+
82+
def _build_encoder(self):
83+
if self._config:
84+
model_config = self._config
85+
else:
86+
path_to_model = os.path.join(
87+
PATH_TO_CONFIG,
88+
"encoder_{}.yaml".format(self.num_layers),
89+
)
90+
model_config = OmegaConf.load(path_to_model)
91+
ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
92+
modules_lib = sys.modules[__name__]
93+
return SparseConv3dEncoder(model_config, None, modules_lib, **self.kwargs)
94+
95+
96+
class BaseSparseConv3d(UnwrappedUnetBasedModel):
97+
CONV_TYPE = "sparse"
98+
99+
def __init__(self, model_config, model_type, modules, *args, **kwargs):
100+
super().__init__(model_config, model_type, modules)
101+
self.weight_initialization()
102+
default_output_nc = kwargs.get("default_output_nc", None)
103+
if not default_output_nc:
104+
default_output_nc = extract_output_nc(model_config)
105+
106+
self._output_nc = default_output_nc
107+
self._has_mlp_head = False
108+
if "output_nc" in kwargs:
109+
self._has_mlp_head = True
110+
self._output_nc = kwargs["output_nc"]
111+
self.mlp = MLP([default_output_nc, self.output_nc], activation=torch.nn.ReLU(), bias=False)
112+
113+
@property
114+
def has_mlp_head(self):
115+
return self._has_mlp_head
116+
117+
@property
118+
def output_nc(self):
119+
return self._output_nc
120+
121+
def weight_initialization(self):
122+
for m in self.modules():
123+
if isinstance(m, sp3d.nn.Conv3d) or isinstance(m, sp3d.nn.Conv3dTranspose):
124+
torch.nn.init.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")
125+
126+
if isinstance(m, sp3d.nn.BatchNorm):
127+
torch.nn.init.constant_(m.bn.weight, 1)
128+
torch.nn.init.constant_(m.bn.bias, 0)
129+
130+
def _set_input(self, data):
131+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
132+
133+
Parameters
134+
-----------
135+
data:
136+
a dictionary that contains the data itself and its metadata information.
137+
"""
138+
self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch)
139+
if data.pos is not None:
140+
self.xyz = data.pos
141+
else:
142+
self.xyz = data.coords
143+
144+
145+
class SparseConv3dEncoder(BaseSparseConv3d):
146+
def forward(self, data, *args, **kwargs):
147+
"""
148+
Parameters:
149+
-----------
150+
data
151+
A SparseTensor that contains the data itself and its metadata information. Should contain
152+
F -- Features [N, C]
153+
coords -- Coords [N, 4]
154+
155+
Returns
156+
--------
157+
data:
158+
- x [1, output_nc]
159+
160+
"""
161+
self._set_input(data)
162+
data = self.input
163+
for i in range(len(self.down_modules)):
164+
data = self.down_modules[i](data)
165+
166+
out = Batch(x=data.F, batch=data.C[:, 0].long().to(data.F.device))
167+
if not isinstance(self.inner_modules[0], Identity):
168+
out = self.inner_modules[0](out)
169+
170+
if self.has_mlp_head:
171+
out.x = self.mlp(out.x)
172+
return out
173+
174+
175+
class SparseConv3dUnet(BaseSparseConv3d):
176+
def forward(self, data, *args, **kwargs):
177+
"""Run forward pass.
178+
Input --- D1 -- D2 -- D3 -- U1 -- U2 -- output
179+
| |_________| |
180+
|______________________|
181+
182+
Parameters
183+
-----------
184+
data
185+
A SparseTensor that contains the data itself and its metadata information. Should contain
186+
F -- Features [N, C]
187+
coords -- Coords [N, 4]
188+
189+
Returns
190+
--------
191+
data:
192+
- pos [N, 3] (coords or real pos if xyz is in data)
193+
- x [N, output_nc]
194+
- batch [N]
195+
"""
196+
self._set_input(data)
197+
data = self.input
198+
stack_down = []
199+
for i in range(len(self.down_modules) - 1):
200+
data = self.down_modules[i](data)
201+
stack_down.append(data)
202+
203+
data = self.down_modules[-1](data)
204+
stack_down.append(None)
205+
# TODO : Manage the inner module
206+
for i in range(len(self.up_modules)):
207+
data = self.up_modules[i](data, stack_down.pop())
208+
209+
out = Batch(x=data.F, pos=self.xyz)
210+
if self.has_mlp_head:
211+
out.x = self.mlp(out.x)
212+
return out
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def extract_output_nc(model_config):
2+
"""Extracts the number of channels at the output of the network form the model config"""
3+
if model_config.get("up_conv") is not None:
4+
output_nc = model_config.up_conv.up_conv_nn[-1][-1]
5+
elif model_config.get("innermost") is not None:
6+
output_nc = model_config.innermost.nn[-1]
7+
else:
8+
raise ValueError("Input model_config does not match expected pattern")
9+
return output_nc
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base_modules import *
2+
from .spatial_transform import *

0 commit comments

Comments
 (0)