Skip to content

Commit bff9162

Browse files
committed
Move Cubic Feature Sampling to a new file.
1 parent ca90fc1 commit bff9162

File tree

5 files changed

+95
-63
lines changed

5 files changed

+95
-63
lines changed

cuda/src/cubic_feature_sampling_gpu.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <cstdlib>
44
#include <torch/extension.h>
55

6+
#include "cuda_utils.h"
7+
68
#define CUDA_NUM_THREADS 512
79

810
// Computer the number of threads needed in GPU

test/test_cubic_feature_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
1313
sys.path.insert(0, ROOT)
1414

15-
from torch_points_kernels import CubicFeatureSamplingFunction, cubic_feature_sampling
15+
from torch_points_kernels.cubic_feature_sampling import CubicFeatureSamplingFunction, cubic_feature_sampling
1616

1717

1818
class TestCubicFeatureSampling(unittest.TestCase):

torch_points_kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .knn import knn
33
from .cluster import region_grow
44
from .metrics import instance_iou
5+
from .cubic_feature_sampling import cubic_feature_sampling
56

67
__all__ = [
78
"ball_query",
@@ -13,4 +14,5 @@
1314
"region_grow",
1415
"instance_iou",
1516
"chamfer_dist",
17+
"cubic_feature_sampling",
1618
]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
3+
if torch.cuda.is_available():
4+
import torch_points_kernels.points_cuda as tpcuda
5+
6+
7+
class CubicFeatureSamplingFunction(torch.autograd.Function):
8+
@staticmethod
9+
def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
10+
scale = cubic_features.size(2)
11+
if not torch.cuda.is_available():
12+
raise NotImplementedError(
13+
"CPU version is not available for Cubic Feature Sampling"
14+
)
15+
16+
point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
17+
scale, neighborhood_size, ptcloud, cubic_features
18+
)
19+
ctx.save_for_backward(
20+
torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes
21+
)
22+
return point_features
23+
24+
@staticmethod
25+
def backward(ctx, grad_point_features):
26+
scale, neighborhood_size, grid_pt_indexes = ctx.saved_tensors
27+
scale = int(scale.item())
28+
neighborhood_size = int(neighborhood_size.item())
29+
grad_point_features = grad_point_features.contiguous()
30+
grad_ptcloud, grad_cubic_features = tpcuda.cubic_feature_sampling_grad(
31+
scale, neighborhood_size, grad_point_features, grid_pt_indexes
32+
)
33+
return grad_ptcloud, grad_cubic_features, None
34+
35+
36+
def cubic_feature_sampling(ptcloud, cubic_features, neighborhood_size=1):
37+
r"""
38+
Sample the features of points from 3D feature maps that the point lies in.
39+
Please refer to https://arxiv.org/pdf/2006.03761 for more information
40+
41+
Parameters
42+
----------
43+
ptcloud : torch.Tensor (dtype=torch.float32)
44+
(B, n_pts, 3) point clouds containing n_pts points
45+
cubic_features : torch.Tensor (dtype=torch.float32)
46+
(B, c, m, m, m) 3D feature maps of sizes m x m x m and c channels
47+
neighborhood_size : int
48+
The neighborhood cubes to sample.
49+
neighborhood_size = 1 means to sample the cube that point lies in.
50+
neighborhood_size = 2 means to sample surrouding cubes (step = 1) of
51+
the cube that point lies in.
52+
53+
Returns
54+
-------
55+
dist: torch.Tensor
56+
(B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
57+
"""
58+
h_scale = cubic_features.size(2) / 2
59+
ptcloud = ptcloud * h_scale + h_scale
60+
return CubicFeatureSamplingFunction.apply(
61+
ptcloud, cubic_features, neighborhood_size
62+
)

torch_points_kernels/torchpoints.py

Lines changed: 28 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def furthest_point_sample(xyz, npoint):
3030
(B, npoint) tensor containing the set
3131
"""
3232
if npoint > xyz.shape[1]:
33-
raise ValueError("caanot sample %i points from an input set of %i points" % (npoint, xyz.shape[1]))
33+
raise ValueError(
34+
"caanot sample %i points from an input set of %i points"
35+
% (npoint, xyz.shape[1])
36+
)
3437
if xyz.is_cuda:
3538
return tpcuda.furthest_point_sampling(xyz, npoint)
3639
else:
@@ -99,9 +102,13 @@ def backward(ctx, grad_out):
99102
idx, weight, m = ctx.three_interpolate_for_backward
100103

101104
if grad_out.is_cuda:
102-
grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
105+
grad_features = tpcuda.three_interpolate_grad(
106+
grad_out.contiguous(), idx, weight, m
107+
)
103108
else:
104-
grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
109+
grad_features = tpcpu.knn_interpolate_grad(
110+
grad_out.contiguous(), idx, weight, m
111+
)
105112

106113
return grad_features, None, None
107114

@@ -143,17 +150,23 @@ def grouping_operation(features, idx):
143150
all_idx = idx.reshape(idx.shape[0], -1)
144151
all_idx = all_idx.unsqueeze(1).repeat(1, features.shape[1], 1)
145152
grouped_features = features.gather(2, all_idx)
146-
return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2])
153+
return grouped_features.reshape(
154+
idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2]
155+
)
147156

148157

149-
def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False):
158+
def ball_query_dense(
159+
radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False
160+
):
150161
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
151162
if new_xyz.is_cuda:
152163
if sort:
153164
raise NotImplementedError("CUDA version does not sort the neighbors")
154165
ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
155166
else:
156-
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
167+
ind, dist = tpcpu.dense_ball_query(
168+
new_xyz, xyz, radius, nsample, mode=0, sorted=sort
169+
)
157170
return ind, dist
158171

159172

@@ -162,9 +175,13 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
162175
if x.is_cuda:
163176
if sort:
164177
raise NotImplementedError("CUDA version does not sort the neighbors")
165-
ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
178+
ind, dist = tpcuda.ball_query_partial_dense(
179+
x, y, batch_x, batch_y, radius, nsample
180+
)
166181
else:
167-
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
182+
ind, dist = tpcpu.batch_ball_query(
183+
x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort
184+
)
168185
return ind, dist
169186

170187

@@ -207,7 +224,9 @@ def ball_query(
207224
assert x.size(0) == batch_x.size(0)
208225
assert y.size(0) == batch_y.size(0)
209226
assert x.dim() == 2
210-
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort)
227+
return ball_query_partial_dense(
228+
radius, nsample, x, y, batch_x, batch_y, sort=sort
229+
)
211230

212231
elif mode.lower() == "dense":
213232
if (batch_x is not None) or (batch_y is not None):
@@ -262,56 +281,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
262281

263282
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
264283
return torch.mean(dist1) + torch.mean(dist2)
265-
266-
267-
class CubicFeatureSamplingFunction(torch.autograd.Function):
268-
@staticmethod
269-
def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
270-
scale = cubic_features.size(2)
271-
point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
272-
scale, neighborhood_size, ptcloud, cubic_features
273-
)
274-
ctx.save_for_backward(
275-
torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes
276-
)
277-
return point_features
278-
279-
@staticmethod
280-
def backward(ctx, grad_point_features):
281-
scale, neighborhood_size, grid_pt_indexes = ctx.saved_tensors
282-
scale = int(scale.item())
283-
neighborhood_size = int(neighborhood_size.item())
284-
grad_point_features = grad_point_features.contiguous()
285-
grad_ptcloud, grad_cubic_features = tpcuda.cubic_feature_sampling_grad(
286-
scale, neighborhood_size, grad_point_features, grid_pt_indexes
287-
)
288-
return grad_ptcloud, grad_cubic_features, None
289-
290-
291-
def cubic_feature_sampling(ptcloud, cubic_features, neighborhood_size=1):
292-
r"""
293-
Sample the features of points from 3D feature maps that the point lies in.
294-
Please refer to https://arxiv.org/pdf/2006.03761 for more information
295-
296-
Parameters
297-
----------
298-
ptcloud : torch.Tensor (dtype=torch.float32)
299-
(B, n_pts, 3) point clouds containing n_pts points
300-
cubic_features : torch.Tensor (dtype=torch.float32)
301-
(B, c, m, m, m) 3D feature maps of sizes m x m x m and c channels
302-
neighborhood_size : int
303-
The neighborhood cubes to sample.
304-
neighborhood_size = 1 means to sample the cube that point lies in.
305-
neighborhood_size = 2 means to sample surrouding cubes (step = 1) of
306-
the cube that point lies in.
307-
308-
Returns
309-
-------
310-
dist: torch.Tensor
311-
(B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
312-
"""
313-
h_scale = cubic_features.size(2) / 2
314-
ptcloud = ptcloud * h_scale + h_scale
315-
return CubicFeatureSamplingFunction.apply(
316-
ptcloud, cubic_features, neighborhood_size
317-
)

0 commit comments

Comments
 (0)