Skip to content

Commit e2c315b

Browse files
committed
Create the unit test for Gridding.
1 parent 8488978 commit e2c315b

File tree

6 files changed

+126
-39
lines changed

6 files changed

+126
-39
lines changed

cuda/src/gridding_gpu.cu

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ __device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y
1919
return offset_x * len_y * len_z + offset_y * len_z + offset_z;
2020
}
2121

22-
__device__ float compute_weight(float x, float x0)
22+
template <typename scalar_t>
23+
__device__ scalar_t compute_weight(scalar_t x, scalar_t x0)
2324
{
2425
return 1 - abs(x - x0);
2526
}
2627

2728
template <typename scalar_t>
2829
__global__ void
29-
gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z,
30-
int len_y, int len_z, const scalar_t* __restrict__ ptcloud,
30+
gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, int len_y,
31+
int len_z, const scalar_t* __restrict__ ptcloud,
3132
scalar_t* __restrict__ grid_weights, scalar_t* __restrict__ grid_pt_weights,
3233
int* __restrict__ grid_pt_indexes)
3334
{
@@ -72,51 +73,51 @@ gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float
7273
// Compute weights and corresponding positions, a loop for 8 points
7374
// LLL -> Lower X, Lower Y, Lower Z
7475
grid_pt_indexes[j * 8 + 0] = compute_index(lx_offset, ly_offset, lz_offset, len_y, len_z);
75-
grid_pt_weights[j * 24 + 0] = compute_weight(pt_x, lower_x);
76-
grid_pt_weights[j * 24 + 1] = compute_weight(pt_y, lower_y);
77-
grid_pt_weights[j * 24 + 2] = compute_weight(pt_z, lower_z);
76+
grid_pt_weights[j * 24 + 0] = compute_weight<scalar_t>(pt_x, lower_x);
77+
grid_pt_weights[j * 24 + 1] = compute_weight<scalar_t>(pt_y, lower_y);
78+
grid_pt_weights[j * 24 + 2] = compute_weight<scalar_t>(pt_z, lower_z);
7879

7980
// LLU -> Lower X, Lower Y, Upper Z
8081
grid_pt_indexes[j * 8 + 1] = compute_index(lx_offset, ly_offset, uz_offset, len_y, len_z);
81-
grid_pt_weights[j * 24 + 3] = compute_weight(pt_x, lower_x);
82-
grid_pt_weights[j * 24 + 4] = compute_weight(pt_y, lower_y);
83-
grid_pt_weights[j * 24 + 5] = compute_weight(pt_z, upper_z);
82+
grid_pt_weights[j * 24 + 3] = compute_weight<scalar_t>(pt_x, lower_x);
83+
grid_pt_weights[j * 24 + 4] = compute_weight<scalar_t>(pt_y, lower_y);
84+
grid_pt_weights[j * 24 + 5] = compute_weight<scalar_t>(pt_z, upper_z);
8485

8586
// LUL -> Lower X, Upper Y, Lower Z
8687
grid_pt_indexes[j * 8 + 2] = compute_index(lx_offset, uy_offset, lz_offset, len_y, len_z);
87-
grid_pt_weights[j * 24 + 6] = compute_weight(pt_x, lower_x);
88-
grid_pt_weights[j * 24 + 7] = compute_weight(pt_y, upper_y);
89-
grid_pt_weights[j * 24 + 8] = compute_weight(pt_z, lower_z);
88+
grid_pt_weights[j * 24 + 6] = compute_weight<scalar_t>(pt_x, lower_x);
89+
grid_pt_weights[j * 24 + 7] = compute_weight<scalar_t>(pt_y, upper_y);
90+
grid_pt_weights[j * 24 + 8] = compute_weight<scalar_t>(pt_z, lower_z);
9091

9192
// LUU -> Lower X, Upper Y, Upper Z
9293
grid_pt_indexes[j * 8 + 3] = compute_index(lx_offset, uy_offset, uz_offset, len_y, len_z);
93-
grid_pt_weights[j * 24 + 9] = compute_weight(pt_x, lower_x);
94-
grid_pt_weights[j * 24 + 10] = compute_weight(pt_y, upper_y);
95-
grid_pt_weights[j * 24 + 11] = compute_weight(pt_z, upper_z);
94+
grid_pt_weights[j * 24 + 9] = compute_weight<scalar_t>(pt_x, lower_x);
95+
grid_pt_weights[j * 24 + 10] = compute_weight<scalar_t>(pt_y, upper_y);
96+
grid_pt_weights[j * 24 + 11] = compute_weight<scalar_t>(pt_z, upper_z);
9697

9798
// ULL -> Upper X, Lower Y, Lower Z
9899
grid_pt_indexes[j * 8 + 4] = compute_index(ux_offset, ly_offset, lz_offset, len_y, len_z);
99-
grid_pt_weights[j * 24 + 12] = compute_weight(pt_x, upper_x);
100-
grid_pt_weights[j * 24 + 13] = compute_weight(pt_y, lower_y);
101-
grid_pt_weights[j * 24 + 14] = compute_weight(pt_z, lower_z);
100+
grid_pt_weights[j * 24 + 12] = compute_weight<scalar_t>(pt_x, upper_x);
101+
grid_pt_weights[j * 24 + 13] = compute_weight<scalar_t>(pt_y, lower_y);
102+
grid_pt_weights[j * 24 + 14] = compute_weight<scalar_t>(pt_z, lower_z);
102103

103104
// ULU -> Upper X, Lower Y, Upper Z
104105
grid_pt_indexes[j * 8 + 5] = compute_index(ux_offset, ly_offset, uz_offset, len_y, len_z);
105-
grid_pt_weights[j * 24 + 15] = compute_weight(pt_x, upper_x);
106-
grid_pt_weights[j * 24 + 16] = compute_weight(pt_y, lower_y);
107-
grid_pt_weights[j * 24 + 17] = compute_weight(pt_z, upper_z);
106+
grid_pt_weights[j * 24 + 15] = compute_weight<scalar_t>(pt_x, upper_x);
107+
grid_pt_weights[j * 24 + 16] = compute_weight<scalar_t>(pt_y, lower_y);
108+
grid_pt_weights[j * 24 + 17] = compute_weight<scalar_t>(pt_z, upper_z);
108109

109110
// UUL -> Upper X, Upper Y, Lower Z
110111
grid_pt_indexes[j * 8 + 6] = compute_index(ux_offset, uy_offset, lz_offset, len_y, len_z);
111-
grid_pt_weights[j * 24 + 18] = compute_weight(pt_x, upper_x);
112-
grid_pt_weights[j * 24 + 19] = compute_weight(pt_y, upper_y);
113-
grid_pt_weights[j * 24 + 20] = compute_weight(pt_z, lower_z);
112+
grid_pt_weights[j * 24 + 18] = compute_weight<scalar_t>(pt_x, upper_x);
113+
grid_pt_weights[j * 24 + 19] = compute_weight<scalar_t>(pt_y, upper_y);
114+
grid_pt_weights[j * 24 + 20] = compute_weight<scalar_t>(pt_z, lower_z);
114115

115116
// UUU -> Upper X, Upper Y, Upper Z
116117
grid_pt_indexes[j * 8 + 7] = compute_index(ux_offset, uy_offset, uz_offset, len_y, len_z);
117-
grid_pt_weights[j * 24 + 21] = compute_weight(pt_x, upper_x);
118-
grid_pt_weights[j * 24 + 22] = compute_weight(pt_y, upper_y);
119-
grid_pt_weights[j * 24 + 23] = compute_weight(pt_z, upper_z);
118+
grid_pt_weights[j * 24 + 21] = compute_weight<scalar_t>(pt_x, upper_x);
119+
grid_pt_weights[j * 24 + 22] = compute_weight<scalar_t>(pt_y, upper_y);
120+
grid_pt_weights[j * 24 + 23] = compute_weight<scalar_t>(pt_z, upper_z);
120121
}
121122

122123
__syncthreads();
@@ -179,9 +180,9 @@ std::vector<torch::Tensor> gridding_kernel_warpper(float min_x, float max_x, flo
179180
int n_grid_vertices = len_x * len_y * len_z;
180181

181182
torch::Tensor grid_weights =
182-
torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(torch::kFloat));
183+
torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(ptcloud.scalar_type()));
183184
torch::Tensor grid_pt_weights =
184-
torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(torch::kFloat));
185+
torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(ptcloud.scalar_type()));
185186
torch::Tensor grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt));
186187

187188
AT_DISPATCH_FLOATING_TYPES(
@@ -310,7 +311,8 @@ torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights,
310311
int n_grid_vertices = grad_grid.size(1);
311312
int n_pts = grid_pt_indexes.size(1);
312313

313-
torch::Tensor grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(torch::kFloat));
314+
torch::Tensor grad_ptcloud =
315+
torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grid_pt_weights.scalar_type()));
314316

315317
AT_DISPATCH_FLOATING_TYPES(
316318
grid_pt_weights.scalar_type(), "gridding_grad_cuda", ([&] {

test/test_gridding.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import os
3+
import sys
4+
import torch
5+
import unittest
6+
7+
from torch.autograd import gradcheck
8+
9+
from . import run_if_cuda
10+
11+
12+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
13+
sys.path.insert(0, ROOT)
14+
15+
from torch_points_kernels.gridding import GriddingFunction
16+
17+
18+
class TestGridding(unittest.TestCase):
19+
@run_if_cuda
20+
def test_gridding_function_32pts(self):
21+
x = torch.rand(1, 32, 3)
22+
x.requires_grad = True
23+
self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 4]))
24+
25+
@run_if_cuda
26+
def test_gridding_function_64pts(self):
27+
x = torch.rand(1, 64, 3)
28+
x.requires_grad = True
29+
self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda(), 8]))
30+

torch_points_kernels/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
"instance_iou",
1616
"chamfer_dist",
1717
"cubic_feature_sampling",
18+
"gridding",
1819
]

torch_points_kernels/chamfer_dist.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ class ChamferFunction(torch.autograd.Function):
88
@staticmethod
99
def forward(ctx, xyz1, xyz2):
1010
if not torch.cuda.is_available():
11-
raise NotImplementedError(
12-
"CPU version is not available for Chamfer Distance"
13-
)
11+
raise NotImplementedError("CPU version is not available for Chamfer Distance")
1412

1513
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
1614
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
@@ -20,9 +18,7 @@ def forward(ctx, xyz1, xyz2):
2018
@staticmethod
2119
def backward(ctx, grad_dist1, grad_dist2):
2220
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
23-
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
24-
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
25-
)
21+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
2622
return grad_xyz1, grad_xyz2
2723

2824

@@ -45,7 +41,7 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
4541
(B, ): the distances between B pairs of point clouds
4642
"""
4743
if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3:
48-
raise ValueError('The input point cloud should be of size (B, n_pts, 3)')
44+
raise ValueError("The input point cloud should be of size (B, n_pts, 3)")
4945

5046
batch_size = xyz1.size(0)
5147
if batch_size == 1 and ignore_zeros:
@@ -56,4 +52,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
5652

5753
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
5854
return torch.mean(dist1) + torch.mean(dist2)
59-

torch_points_kernels/gridding.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import torch
2+
3+
if torch.cuda.is_available():
4+
import torch_points_kernels.points_cuda as tpcuda
5+
6+
7+
class GriddingFunction(torch.autograd.Function):
8+
@staticmethod
9+
def forward(ctx, ptcloud, scale):
10+
if not torch.cuda.is_available():
11+
raise NotImplementedError("CPU version is not available for Chamfer Distance")
12+
13+
grid, grid_pt_weights, grid_pt_indexes = tpcuda.gridding(
14+
-scale, scale - 1, -scale, scale - 1, -scale, scale - 1, ptcloud
15+
)
16+
# print(grid.size()) # torch.Size(batch_size, n_grid_vertices)
17+
# print(grid_pt_weights.size()) # torch.Size(batch_size, n_pts, 8, 3)
18+
# print(grid_pt_indexes.size()) # torch.Size(batch_size, n_pts, 8)
19+
ctx.save_for_backward(grid_pt_weights, grid_pt_indexes)
20+
21+
return grid
22+
23+
@staticmethod
24+
def backward(ctx, grad_grid):
25+
grid_pt_weights, grid_pt_indexes = ctx.saved_tensors
26+
grad_ptcloud = tpcuda.gridding_grad(grid_pt_weights, grid_pt_indexes, grad_grid)
27+
# print(grad_ptcloud.size()) # torch.Size(batch_size, n_pts, 3)
28+
29+
return grad_ptcloud, None
30+
31+
32+
def gridding(ptcloud, scale):
33+
r"""
34+
Converts the input point clouds into 3D grids by trilinear interpolation.
35+
Please refer to https://arxiv.org/pdf/2006.03761 for more information
36+
37+
Parameters
38+
----------
39+
ptcloud : torch.Tensor (dtype=torch.float32)
40+
(B, n_pts, 3) B point clouds containing n_pts points
41+
scale : Int
42+
the resolution of the 3D grid
43+
44+
Returns
45+
-------
46+
grid: torch.Tensor
47+
(B, scale, scale, scale): the grid of the resolution of scale * scale * scale
48+
"""
49+
if len(ptcloud.shape) != 3 or ptcloud.size(2) != 3:
50+
raise ValueError("The input point cloud should be of size (B, n_pts, 3)")
51+
52+
ptcloud = ptcloud * scale
53+
_ptcloud = torch.split(ptcloud, 1, dim=0)
54+
grids = []
55+
for p in _ptcloud:
56+
non_zeros = torch.sum(p, dim=2).ne(0)
57+
p = p[non_zeros].unsqueeze(dim=0)
58+
grids.append(GriddingFunction.apply(p, scale))
59+
60+
return torch.cat(grids, dim=0).contiguous()

torch_points_kernels/torchpoints.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,3 @@ def ball_query(
216216
return ball_query_dense(radius, nsample, x, y, sort=sort)
217217
else:
218218
raise Exception("unrecognized mode {}".format(mode))
219-

0 commit comments

Comments
 (0)