Skip to content

Commit cbf6d45

Browse files
committed
move scheduler/optim to be contained within the model, not in global scope
separate "backbones" from "lightning models" fix sp3d sparsetensor device add basic data transforms (need to structure better than original repo) add training methods
1 parent 88bcc70 commit cbf6d45

File tree

22 files changed

+352
-238
lines changed

22 files changed

+352
-238
lines changed

conf/config.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
defaults: # loads default configs
22
- dataset: ???
3-
- optimizer: sgd
4-
- scheduler: default
53
- model: ???
64
- training: default
75
- trainer: default

conf/model/default.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
# @package model
2+
defaults:
3+
- /optimizer: sgd
4+
- /scheduler:
15
# By default we turn off recursive instantiation, allowing the user to instantiate themselves at the appropriate times.
26
_recursive_: false
37

4-
_target_: torch_points3d.models.base_model.PointCloudBaseModel
5-
optimizer: ${optimizer}
6-
scheduler: ${scheduler}
8+
_target_: torch_points3d.models.base_model.PointCloudBaseModel
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# @package model
2+
defaults:
3+
- /model/default
4+
5+
_target_: torch_points3d.models.segmentation.base_model.SegmentationBaseModel
6+
7+
backbone:
8+
architecture: unet

conf/model/segmentation/sparseconv3d.yaml

Lines changed: 0 additions & 74 deletions
This file was deleted.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# @package model
2+
defaults:
3+
- /model/segmentation/ResUNet32
4+
5+
backbone:
6+
down_conv:
7+
N: [ 0, 2, 3, 4, 6 ]
8+
up_conv:
9+
N: [ 1, 1, 1, 1, 1 ]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# @package model
2+
defaults:
3+
- /model/segmentation/default
4+
5+
backbone:
6+
_target_: torch_points3d.applications.sparseconv3d.SparseConv3d
7+
backend: torchsparse
8+
9+
config:
10+
define_constants:
11+
in_feat: 32
12+
block: ResBlock # Can be any of the blocks in modules/MinkowskiEngine/api_modules.py
13+
down_conv:
14+
module_name: ResNetDown
15+
block: block
16+
N: [ 0, 1, 2, 2, 3 ]
17+
down_conv_nn:
18+
[
19+
[ FEAT, in_feat ],
20+
[ in_feat, in_feat ],
21+
[ in_feat, 2*in_feat ],
22+
[ 2*in_feat, 4*in_feat ],
23+
[ 4*in_feat, 8*in_feat ],
24+
]
25+
kernel_size: 3
26+
stride: [ 1, 2, 2, 2, 2 ]
27+
up_conv:
28+
block: block
29+
module_name: ResNetUp
30+
N: [ 1, 1, 1, 1, 0 ]
31+
up_conv_nn:
32+
[
33+
[ 8*in_feat, 4*in_feat ],
34+
[ 4*in_feat + 4*in_feat, 4*in_feat ],
35+
[ 4*in_feat + 2*in_feat, 3*in_feat ],
36+
[ 3*in_feat + in_feat, 3*in_feat ],
37+
[ 3*in_feat + in_feat, 3*in_feat ],
38+
]
39+
kernel_size: 3
40+
stride: [ 2, 2, 2, 2, 1 ]

conf/scheduler/default.yaml

Whitespace-only changes.

conf/trainer/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ gradient_clip_val: 0.0
66
process_position: 0
77
num_nodes: 1
88
num_processes: 1
9-
gpus: null
9+
gpus: 1
1010
auto_select_gpus: False
1111
tpu_cores: null
1212
log_gpu_memory: null

torch_points3d/applications/sparseconv3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _set_input(self, data):
135135
data:
136136
a dictionary that contains the data itself and its metadata information.
137137
"""
138-
self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch)
138+
self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch, self.device)
139139
if data.pos is not None:
140140
self.xyz = data.pos
141141
else:
@@ -163,7 +163,7 @@ def forward(self, data, *args, **kwargs):
163163
for i in range(len(self.down_modules)):
164164
data = self.down_modules[i](data)
165165

166-
out = Batch(x=data.F, batch=data.C[:, 0].long().to(data.F.device))
166+
out = Batch(x=data.F, batch=data.C[:, 0].long())
167167
if not isinstance(self.inner_modules[0], Identity):
168168
out = self.inner_modules[0](out)
169169

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import torch
2+
import re
3+
import logging
4+
import torch.nn.functional as F
5+
from torch_scatter import scatter_mean, scatter_add
6+
from torch_geometric.nn.pool.consecutive import consecutive_cluster
7+
from torch_geometric.nn import voxel_grid
8+
from torch_cluster import grid_cluster
9+
10+
log = logging.getLogger(__name__)
11+
12+
class AddOnes(object):
13+
"""
14+
Add ones tensor to data
15+
"""
16+
17+
def __call__(self, data):
18+
num_nodes = data.pos.shape[0]
19+
data.ones = torch.ones((num_nodes, 1)).float()
20+
return data
21+
22+
def __repr__(self):
23+
return "{}()".format(self.__class__.__name__)
24+
25+
26+
# Label will be the majority label in each voxel
27+
_INTEGER_LABEL_KEYS = ["y", "instance_labels"]
28+
29+
def shuffle_data(data):
30+
num_points = data.pos.shape[0]
31+
shuffle_idx = torch.randperm(num_points)
32+
for key in set(data.keys):
33+
item = data[key]
34+
if torch.is_tensor(item) and num_points == item.shape[0]:
35+
data[key] = item[shuffle_idx]
36+
return data
37+
38+
39+
def group_data(data, cluster=None, unique_pos_indices=None, mode="last", skip_keys=[]):
40+
""" Group data based on indices in cluster.
41+
The option ``mode`` controls how data gets agregated within each cluster.
42+
43+
Parameters
44+
----------
45+
data : Data
46+
[description]
47+
cluster : torch.Tensor
48+
Tensor of the same size as the number of points in data. Each element is the cluster index of that point.
49+
unique_pos_indices : torch.tensor
50+
Tensor containing one index per cluster, this index will be used to select features and labels
51+
mode : str
52+
Option to select how the features and labels for each voxel is computed. Can be ``last`` or ``mean``.
53+
``last`` selects the last point falling in a voxel as the representent, ``mean`` takes the average.
54+
skip_keys: list
55+
Keys of attributes to skip in the grouping
56+
"""
57+
58+
assert mode in ["mean", "last"]
59+
if mode == "mean" and cluster is None:
60+
raise ValueError("In mean mode the cluster argument needs to be specified")
61+
if mode == "last" and unique_pos_indices is None:
62+
raise ValueError("In last mode the unique_pos_indices argument needs to be specified")
63+
64+
num_nodes = data.num_nodes
65+
for key, item in data:
66+
if bool(re.search("edge", key)):
67+
raise ValueError("Edges not supported. Wrong data type.")
68+
if key in skip_keys:
69+
continue
70+
71+
if torch.is_tensor(item) and item.size(0) == num_nodes:
72+
if mode == "last" or key == "batch": #or key == SaveOriginalPosId.KEY:
73+
data[key] = item[unique_pos_indices]
74+
elif mode == "mean":
75+
is_item_bool = item.dtype == torch.bool
76+
if is_item_bool:
77+
item = item.int()
78+
if key in _INTEGER_LABEL_KEYS:
79+
item_min = item.min()
80+
item = F.one_hot(item - item_min)
81+
item = scatter_add(item, cluster, dim=0)
82+
data[key] = item.argmax(dim=-1) + item_min
83+
else:
84+
data[key] = scatter_mean(item, cluster, dim=0)
85+
if is_item_bool:
86+
data[key] = data[key].bool()
87+
return data
88+
89+
# todo: replace these with minkowski/torchsparse impl?
90+
class GridSampling3D:
91+
""" Clusters points into voxels with size :attr:`size`.
92+
Parameters
93+
----------
94+
size: float
95+
Size of a voxel (in each dimension).
96+
quantize_coords: bool
97+
If True, it will convert the points into their associated sparse coordinates within the grid and store
98+
the value into a new `coords` attribute
99+
mode: string:
100+
The mode can be either `last` or `mean`.
101+
If mode is `mean`, all the points and their features within a cell will be averaged
102+
If mode is `last`, one random points per cell will be selected with its associated features
103+
"""
104+
105+
def __init__(self, size, quantize_coords=False, mode="mean", verbose=False):
106+
self._grid_size = size
107+
self._quantize_coords = quantize_coords
108+
self._mode = mode
109+
if verbose:
110+
log.warning(
111+
"If you need to keep track of the position of your points, use SaveOriginalPosId transform before using GridSampling3D"
112+
)
113+
114+
if self._mode == "last":
115+
log.warning(
116+
"The tensors within data will be shuffled each time this transform is applied. Be careful that if an attribute doesn't have the size of num_points, it won't be shuffled"
117+
)
118+
119+
def _process(self, data):
120+
if self._mode == "last":
121+
data = shuffle_data(data)
122+
123+
coords = torch.round((data.pos) / self._grid_size)
124+
if "batch" not in data:
125+
cluster = grid_cluster(coords, torch.tensor([1, 1, 1]))
126+
else:
127+
cluster = voxel_grid(coords, data.batch, 1)
128+
cluster, unique_pos_indices = consecutive_cluster(cluster)
129+
130+
data = group_data(data, cluster, unique_pos_indices, mode=self._mode)
131+
if self._quantize_coords:
132+
data.coords = coords[unique_pos_indices].int()
133+
134+
data.grid_size = torch.tensor([self._grid_size])
135+
return data
136+
137+
def __call__(self, data):
138+
if isinstance(data, list):
139+
data = [self._process(d) for d in data]
140+
else:
141+
data = self._process(data)
142+
return data
143+
144+
def __repr__(self):
145+
return "{}(grid_size={}, quantize_coords={}, mode={})".format(
146+
self.__class__.__name__, self._grid_size, self._quantize_coords, self._mode
147+
)

0 commit comments

Comments
 (0)