Skip to content

Commit 61a8423

Browse files
committed
update to add more model seperation
1 parent 7e06969 commit 61a8423

File tree

21 files changed

+215
-194
lines changed

21 files changed

+215
-194
lines changed

conf/dataset/default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# cfg:
33
# torch data-loader specific arguments
44
cfg:
5+
num_classes:
6+
feature_dimension:
57
batch_size: ${training.batch_size}
68
num_workers: ${training.num_workers}
79
dataroot: data

conf/dataset/segmentation/s3dis/s3dis1x1.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ defaults:
33
- segmentation/default
44
_target_: torch_points3d.datasets.s3dis1x1.s3dis_data_module
55
cfg:
6+
num_classes: 13
7+
feature_dimension: 6 # todo: able to calculate this dynamically
68
fold: 5

conf/model/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ defaults:
55
# By default we turn off recursive instantiation, allowing the user to instantiate themselves at the appropriate times.
66
_recursive_: false
77

8-
_target_: torch_points3d.models.base_model.PointCloudBaseModel
8+
_target_: torch_points3d.tasks.base_model.PointCloudBaseModule

conf/model/segmentation/default.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
defaults:
33
- /model/default
44

5-
_target_: torch_points3d.models.segmentation.base_model.SegmentationBaseModel
5+
model:
6+
_recursive_: false
7+
_target_: torch_points3d.models.segmentation.base_model.SegmentationBaseModel
8+
num_classes: ${dataset.cfg.num_classes}
69

7-
backbone:
8-
architecture: unet
10+
backbone:
11+
input_nc: ${dataset.cfg.feature_dimension}
12+
architecture: unet

conf/model/segmentation/sparseconv3d/Res16UNet34.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
defaults:
33
- /model/segmentation/ResUNet32
44

5-
backbone:
6-
down_conv:
7-
N: [ 0, 2, 3, 4, 6 ]
8-
up_conv:
9-
N: [ 1, 1, 1, 1, 1 ]
5+
model:
6+
backbone:
7+
down_conv:
8+
N: [ 0, 2, 3, 4, 6 ]
9+
up_conv:
10+
N: [ 1, 1, 1, 1, 1 ]

conf/model/segmentation/sparseconv3d/ResUNet32.yaml

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,40 @@
22
defaults:
33
- /model/segmentation/default
44

5-
backbone:
6-
_target_: torch_points3d.applications.sparseconv3d.SparseConv3d
7-
backend: torchsparse
5+
model:
6+
backbone:
7+
_target_: torch_points3d.applications.sparseconv3d.SparseConv3d
8+
backend: torchsparse
89

9-
config:
10-
define_constants:
11-
in_feat: 32
12-
block: ResBlock # Can be any of the blocks in modules/MinkowskiEngine/api_modules.py
13-
down_conv:
14-
module_name: ResNetDown
15-
block: block
16-
N: [ 0, 1, 2, 2, 3 ]
17-
down_conv_nn:
18-
[
19-
[ FEAT, in_feat ],
20-
[ in_feat, in_feat ],
21-
[ in_feat, 2*in_feat ],
22-
[ 2*in_feat, 4*in_feat ],
23-
[ 4*in_feat, 8*in_feat ],
24-
]
25-
kernel_size: 3
26-
stride: [ 1, 2, 2, 2, 2 ]
27-
up_conv:
28-
block: block
29-
module_name: ResNetUp
30-
N: [ 1, 1, 1, 1, 0 ]
31-
up_conv_nn:
32-
[
33-
[ 8*in_feat, 4*in_feat ],
34-
[ 4*in_feat + 4*in_feat, 4*in_feat ],
35-
[ 4*in_feat + 2*in_feat, 3*in_feat ],
36-
[ 3*in_feat + in_feat, 3*in_feat ],
37-
[ 3*in_feat + in_feat, 3*in_feat ],
38-
]
39-
kernel_size: 3
40-
stride: [ 2, 2, 2, 2, 1 ]
10+
config:
11+
define_constants:
12+
in_feat: 32
13+
block: ResBlock # Can be any of the blocks in modules/MinkowskiEngine/api_modules.py
14+
down_conv:
15+
module_name: ResNetDown
16+
block: block
17+
N: [ 0, 1, 2, 2, 3 ]
18+
down_conv_nn:
19+
[
20+
[ FEAT, in_feat ],
21+
[ in_feat, in_feat ],
22+
[ in_feat, 2*in_feat ],
23+
[ 2*in_feat, 4*in_feat ],
24+
[ 4*in_feat, 8*in_feat ],
25+
]
26+
kernel_size: 3
27+
stride: [ 1, 2, 2, 2, 2 ]
28+
up_conv:
29+
block: block
30+
module_name: ResNetUp
31+
N: [ 1, 1, 1, 1, 0 ]
32+
up_conv_nn:
33+
[
34+
[ 8*in_feat, 4*in_feat ],
35+
[ 4*in_feat + 4*in_feat, 4*in_feat ],
36+
[ 4*in_feat + 2*in_feat, 3*in_feat ],
37+
[ 3*in_feat + in_feat, 3*in_feat ],
38+
[ 3*in_feat + in_feat, 3*in_feat ],
39+
]
40+
kernel_size: 3
41+
stride: [ 2, 2, 2, 2, 1 ]

torch_points3d/models/base_architectures/unet.py renamed to torch_points3d/applications/base_architectures/unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
BatchNorm1d as BN,
1414
Dropout,
1515
)
16-
from torch_points3d.models.base_architectures.base_model import BaseModel
16+
from torch_points3d.applications.base_architectures.base_model import BaseModel
1717
from omegaconf.listconfig import ListConfig
1818
from omegaconf.dictconfig import DictConfig
1919
import logging

0 commit comments

Comments
 (0)