Skip to content

Commit ca90fc1

Browse files
committed
Create Python API and test cases for Cubic Feature Sampling.
1 parent 1af4a55 commit ca90fc1

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
import os
3+
import sys
4+
import torch
5+
import unittest
6+
7+
from torch.autograd import gradcheck
8+
9+
from . import run_if_cuda
10+
11+
12+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
13+
sys.path.insert(0, ROOT)
14+
15+
from torch_points_kernels import CubicFeatureSamplingFunction, cubic_feature_sampling
16+
17+
18+
class TestCubicFeatureSampling(unittest.TestCase):
19+
@run_if_cuda
20+
def test_neighborhood_size_1(self):
21+
ptcloud = torch.rand(2, 64, 3) * 2 - 1
22+
cubic_features = torch.rand(2, 4, 8, 8, 8)
23+
ptcloud.requires_grad = True
24+
cubic_features.requires_grad = True
25+
self.assertTrue(
26+
gradcheck(
27+
CubicFeatureSamplingFunction.apply,
28+
[ptcloud.double().cuda(), cubic_features.double().cuda()],
29+
)
30+
)
31+
32+
@run_if_cuda
33+
def test_neighborhood_size_2(self):
34+
ptcloud = torch.rand(2, 32, 3) * 2 - 1
35+
cubic_features = torch.rand(2, 2, 8, 8, 8)
36+
ptcloud.requires_grad = True
37+
cubic_features.requires_grad = True
38+
self.assertTrue(
39+
gradcheck(
40+
CubicFeatureSamplingFunction.apply,
41+
[ptcloud.double().cuda(), cubic_features.double().cuda(), 2],
42+
)
43+
)
44+
45+
@run_if_cuda
46+
def test_neighborhood_size_3(self):
47+
ptcloud = torch.rand(1, 32, 3) * 2 - 1
48+
cubic_features = torch.rand(1, 2, 16, 16, 16)
49+
ptcloud.requires_grad = True
50+
cubic_features.requires_grad = True
51+
self.assertTrue(
52+
gradcheck(
53+
CubicFeatureSamplingFunction.apply,
54+
[ptcloud.double().cuda(), cubic_features.double().cuda(), 3],
55+
)
56+
)
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

torch_points_kernels/torchpoints.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ def forward(ctx, xyz1, xyz2):
229229
@staticmethod
230230
def backward(ctx, grad_dist1, grad_dist2):
231231
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
232-
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
232+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
233+
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
234+
)
233235
return grad_xyz1, grad_xyz2
234236

235237

@@ -260,3 +262,56 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
260262

261263
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
262264
return torch.mean(dist1) + torch.mean(dist2)
265+
266+
267+
class CubicFeatureSamplingFunction(torch.autograd.Function):
268+
@staticmethod
269+
def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
270+
scale = cubic_features.size(2)
271+
point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
272+
scale, neighborhood_size, ptcloud, cubic_features
273+
)
274+
ctx.save_for_backward(
275+
torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes
276+
)
277+
return point_features
278+
279+
@staticmethod
280+
def backward(ctx, grad_point_features):
281+
scale, neighborhood_size, grid_pt_indexes = ctx.saved_tensors
282+
scale = int(scale.item())
283+
neighborhood_size = int(neighborhood_size.item())
284+
grad_point_features = grad_point_features.contiguous()
285+
grad_ptcloud, grad_cubic_features = tpcuda.cubic_feature_sampling_grad(
286+
scale, neighborhood_size, grad_point_features, grid_pt_indexes
287+
)
288+
return grad_ptcloud, grad_cubic_features, None
289+
290+
291+
def cubic_feature_sampling(ptcloud, cubic_features, neighborhood_size=1):
292+
r"""
293+
Sample the features of points from 3D feature maps that the point lies in.
294+
Please refer to https://arxiv.org/pdf/2006.03761 for more information
295+
296+
Parameters
297+
----------
298+
ptcloud : torch.Tensor (dtype=torch.float32)
299+
(B, n_pts, 3) point clouds containing n_pts points
300+
cubic_features : torch.Tensor (dtype=torch.float32)
301+
(B, c, m, m, m) 3D feature maps of sizes m x m x m and c channels
302+
neighborhood_size : int
303+
The neighborhood cubes to sample.
304+
neighborhood_size = 1 means to sample the cube that point lies in.
305+
neighborhood_size = 2 means to sample surrouding cubes (step = 1) of
306+
the cube that point lies in.
307+
308+
Returns
309+
-------
310+
dist: torch.Tensor
311+
(B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
312+
"""
313+
h_scale = cubic_features.size(2) / 2
314+
ptcloud = ptcloud * h_scale + h_scale
315+
return CubicFeatureSamplingFunction.apply(
316+
ptcloud, cubic_features, neighborhood_size
317+
)

0 commit comments

Comments
 (0)