Commit 13c5a52

Merge pull request #47 from hzxie/cubic-feature-sampling

Add Cubic feature sampling

2 parents: ea6277f + eda4158

16 files changed: +513 -22 lines
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+#include <vector>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+std::vector<torch::Tensor> cubic_feature_sampling(int scale, int neighborhood_size,
+                                                  torch::Tensor ptcloud,
+                                                  torch::Tensor cubic_features);
+
+std::vector<torch::Tensor> cubic_feature_sampling_grad(int scale, int neighborhood_size,
+                                                       torch::Tensor grad_point_features,
+                                                       torch::Tensor grid_pt_indexes);
+
+std::vector<torch::Tensor> cubic_feature_sampling_kernel_wrapper(int scale, int neighborhood_size,
+                                                                 torch::Tensor ptcloud,
+                                                                 torch::Tensor cubic_features,
+                                                                 cudaStream_t stream);
+
+std::vector<torch::Tensor>
+cubic_feature_sampling_grad_kernel_wrapper(int scale, int neighborhood_size,
+                                           torch::Tensor grad_point_features,
+                                           torch::Tensor grid_pt_indexes, cudaStream_t stream);

cuda/src/bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include "ball_query.h"
 #include "chamfer_dist.h"
+#include "cubic_feature_sampling.h"
 #include "interpolate.h"
 #include "metrics.h"
 #include "sampling.h"
@@ -19,4 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 
     m.def("chamfer_dist", &chamfer_dist);
    m.def("chamfer_dist_grad", &chamfer_dist_grad);
+
+    m.def("cubic_feature_sampling", &cubic_feature_sampling);
+    m.def("cubic_feature_sampling_grad", &cubic_feature_sampling_grad);
 }
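For context, the two bindings registered above are meant to be paired on the Python side: the first returns the sampled point features together with the grid indexes, and the second consumes the upstream gradient plus those indexes. Below is a minimal autograd wrapper sketch, assuming the compiled extension is importable as torch_points_kernels.points_cuda (the name declared in setup.py); the CubicFeatureSamplingFunction class is illustrative and not part of this commit.

import torch
import torch_points_kernels.points_cuda as tpcuda  # assumed import path (from setup.py)


class CubicFeatureSamplingFunction(torch.autograd.Function):
    """Hypothetical wrapper pairing the two new bindings; not part of this commit."""

    @staticmethod
    def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
        scale = cubic_features.size(2)  # cubic_features: (B, C, scale, scale, scale)
        point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
            scale, neighborhood_size, ptcloud, cubic_features
        )
        ctx.scale = scale
        ctx.neighborhood_size = neighborhood_size
        ctx.save_for_backward(grid_pt_indexes)
        return point_features

    @staticmethod
    def backward(ctx, grad_point_features):
        (grid_pt_indexes,) = ctx.saved_tensors
        grad_ptcloud, grad_cubic_features = tpcuda.cubic_feature_sampling_grad(
            ctx.scale, ctx.neighborhood_size, grad_point_features.contiguous(), grid_pt_indexes
        )
        # One gradient per forward input: ptcloud, cubic_features, neighborhood_size.
        return grad_ptcloud, grad_cubic_features, None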
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+#include "cubic_feature_sampling.h"
+#include "utils.h"
+
+std::vector<torch::Tensor> cubic_feature_sampling(int scale, int neighborhood_size,
+                                                  torch::Tensor ptcloud,
+                                                  torch::Tensor cubic_features)
+{
+    CHECK_CUDA(ptcloud);
+    CHECK_CONTIGUOUS(ptcloud);
+    CHECK_CUDA(cubic_features);
+    CHECK_CONTIGUOUS(cubic_features);
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    return cubic_feature_sampling_kernel_wrapper(scale, neighborhood_size, ptcloud, cubic_features,
+                                                 stream);
+}
+
+std::vector<torch::Tensor> cubic_feature_sampling_grad(int scale, int neighborhood_size,
+                                                       torch::Tensor grad_point_features,
+                                                       torch::Tensor grid_pt_indexes)
+{
+    CHECK_CUDA(grad_point_features);
+    CHECK_CONTIGUOUS(grad_point_features);
+    CHECK_CUDA(grid_pt_indexes);
+    CHECK_CONTIGUOUS(grid_pt_indexes);
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    return cubic_feature_sampling_grad_kernel_wrapper(scale, neighborhood_size, grad_point_features,
+                                                      grid_pt_indexes, stream);
+}
Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <torch/extension.h>
+
+#include "cuda_utils.h"
+
+#define CUDA_NUM_THREADS 512
+
+// Compute the number of threads needed on the GPU
+inline int get_n_threads(int n)
+{
+    const int pow_2 = std::log(static_cast<float>(n)) / std::log(2.0);
+    return max(min(1 << pow_2, CUDA_NUM_THREADS), 1);
+}
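As an aside (illustration only, not part of the diff), get_n_threads launches roughly the largest power-of-two number of threads that does not exceed n, clamped to the range [1, CUDA_NUM_THREADS]. A Python equivalent of the same calculation:

import math

def get_n_threads(n, cuda_num_threads=512):
    # Largest power of two <= n (up to float rounding), clamped to [1, cuda_num_threads].
    pow_2 = int(math.log(n) / math.log(2.0))
    return max(min(1 << pow_2, cuda_num_threads), 1)

# get_n_threads(100) == 64, get_n_threads(5000) == 512, get_n_threads(1) == 1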
+
+__device__ int compute_index(int offset_x, int offset_y, int offset_z, int scale)
+{
+    return offset_x * scale * scale + offset_y * scale + offset_z;
+}
+
+template <typename scalar_t>
+__global__ void cubic_feature_sampling_kernel(int scale, int neighborhood_size, int n_vertices,
+                                              int n_pts, int n_cubic_channels,
+                                              const scalar_t* __restrict__ ptcloud,
+                                              const scalar_t* __restrict__ cubic_features,
+                                              scalar_t* __restrict__ point_features,
+                                              int* __restrict__ grid_pt_indexes)
+{
+    int batch_index = blockIdx.x;
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    int cub_scale = scale * scale * scale;
+
+    ptcloud += batch_index * n_pts * 3;
+    cubic_features += batch_index * n_cubic_channels * cub_scale;
+    point_features += batch_index * n_pts * n_vertices * n_cubic_channels;
+    grid_pt_indexes += batch_index * n_pts * n_vertices;
+
+    for (int i = index; i < n_pts; i += stride)
+    {
+        scalar_t pt_x = ptcloud[i * 3 + 0];
+        scalar_t pt_y = ptcloud[i * 3 + 1];
+        scalar_t pt_z = ptcloud[i * 3 + 2];
+
+        int lower_x = std::floor(pt_x);
+        int upper_x = std::ceil(pt_x);
+        if (lower_x == upper_x)
+        {
+            upper_x += 1;
+        }
+        int lower_y = std::floor(pt_y);
+        int upper_y = std::ceil(pt_y);
+        if (lower_y == upper_y)
+        {
+            upper_y += 1;
+        }
+        int lower_z = std::floor(pt_z);
+        int upper_z = std::ceil(pt_z);
+        if (lower_z == upper_z)
+        {
+            upper_z += 1;
+        }
+
+        int ns = neighborhood_size - 1;
+        int vertex_idx = 0;
+        for (int j = lower_x - ns; j <= upper_x + ns; ++j)
+        {
+            for (int k = lower_y - ns; k <= upper_y + ns; ++k)
+            {
+                for (int m = lower_z - ns; m <= upper_z + ns; ++m)
+                {
+                    if (j < 0 || j >= scale || k < 0 || k >= scale || m < 0 || m >= scale)
+                    {
+                        // Ignore vertices that lie outside the grid
+                        grid_pt_indexes[i * n_vertices + vertex_idx++] = -1;
+                    }
+                    else
+                    {
+                        // Calculating indexes for adjacent vertices
+                        grid_pt_indexes[i * n_vertices + vertex_idx++] =
+                            compute_index(j, k, m, scale);
+                    }
+                }
+            }
+        }
+
+        // Gather features
+        for (int j = 0; j < n_vertices; ++j)
+        {
+            for (int k = 0; k < n_cubic_channels; ++k)
+            {
+                int vertex_idx = grid_pt_indexes[i * n_vertices + j];
+                if (vertex_idx == -1)
+                {
+                    continue;
+                }
+                int feature_idx = i * n_vertices * n_cubic_channels + j * n_cubic_channels + k;
+                scalar_t feature_val = cubic_features[k * cub_scale + vertex_idx];
+                point_features[feature_idx] = feature_val;
+            }
+        }
+    }
+}
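To make the vertex enumeration concrete: for each point the kernel takes the lower/upper cell corner per axis and widens that range by neighborhood_size - 1 cells on each side, giving 2 * neighborhood_size candidate coordinates per axis and (2 * neighborhood_size)^3 candidate vertices per point, each flattened with compute_index or marked -1 when it falls outside the scale^3 grid. A small Python sketch of the same per-point loop (illustration only, not in the diff):

import math

def candidate_vertices(pt, neighborhood_size, scale):
    # Mirrors the per-point loop of cubic_feature_sampling_kernel: returns the
    # flattened grid index of each candidate vertex, or -1 if it is off-grid.
    ns = neighborhood_size - 1
    ranges = []
    for c in pt:  # lower/upper corners of the enclosing cell, per axis
        lower, upper = math.floor(c), math.ceil(c)
        if lower == upper:
            upper += 1
        ranges.append(range(lower - ns, upper + ns + 1))
    indexes = []
    for x in ranges[0]:
        for y in ranges[1]:
            for z in ranges[2]:
                if 0 <= x < scale and 0 <= y < scale and 0 <= z < scale:
                    indexes.append(x * scale * scale + y * scale + z)  # compute_index
                else:
                    indexes.append(-1)
    return indexes

# neighborhood_size=1 -> the 8 corners of the enclosing cell
assert len(candidate_vertices((2.3, 4.7, 1.1), neighborhood_size=1, scale=32)) == 8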
+
+std::vector<torch::Tensor> cubic_feature_sampling_kernel_wrapper(int scale, int neighborhood_size,
+                                                                 torch::Tensor ptcloud,
+                                                                 torch::Tensor cubic_features,
+                                                                 cudaStream_t stream)
+{
+    int batch_size = ptcloud.size(0);
+    int n_pts = ptcloud.size(1);
+    int n_cubic_channels = cubic_features.size(1);
+
+    int n_vertices = std::pow(neighborhood_size * 2, 3);
+    torch::Tensor point_features = torch::zeros({batch_size, n_pts, n_vertices, n_cubic_channels},
+                                                torch::CUDA(ptcloud.scalar_type()));
+    torch::Tensor grid_pt_indexes =
+        torch::zeros({batch_size, n_pts, n_vertices}, torch::CUDA(torch::kInt));
+
+    AT_DISPATCH_FLOATING_TYPES(
+        ptcloud.scalar_type(), "cubic_feature_sampling_cuda", ([&] {
+            cubic_feature_sampling_kernel<<<batch_size, get_n_threads(n_pts), 0, stream>>>(
+                scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels,
+                ptcloud.data_ptr<scalar_t>(), cubic_features.data_ptr<scalar_t>(),
+                point_features.data_ptr<scalar_t>(), grid_pt_indexes.data_ptr<int>());
+        }));
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+    {
+        printf("Error in cubic_feature_sampling_kernel_wrapper: %s\n", cudaGetErrorString(err));
+    }
+    return {point_features, grid_pt_indexes};
+}
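In terms of shapes, the wrapper above expects ptcloud of shape (B, n_pts, 3) with coordinates already expressed in units of the scale^3 grid, and cubic_features of shape (B, C, scale, scale, scale); it returns point_features of shape (B, n_pts, (2*neighborhood_size)^3, C) and grid_pt_indexes of shape (B, n_pts, (2*neighborhood_size)^3). A shape-check sketch, again assuming the extension is importable as torch_points_kernels.points_cuda:

import torch
import torch_points_kernels.points_cuda as tpcuda  # assumed import path

B, N, C, scale, neighborhood_size = 2, 1024, 32, 16, 1

ptcloud = torch.rand(B, N, 3, device="cuda") * scale  # coordinates in grid units
cubic_features = torch.rand(B, C, scale, scale, scale, device="cuda")

point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
    scale, neighborhood_size, ptcloud, cubic_features
)
n_vertices = (2 * neighborhood_size) ** 3  # 8 when neighborhood_size == 1
assert point_features.shape == (B, N, n_vertices, C)
assert grid_pt_indexes.shape == (B, N, n_vertices)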
+
+template <typename scalar_t>
+__global__ void cubic_feature_sampling_grad_kernel(int scale, int neighborhood_size, int n_vertices,
+                                                   int n_pts, int n_cubic_channels,
+                                                   const scalar_t* __restrict__ grad_point_features,
+                                                   const int* __restrict__ grid_pt_indexes,
+                                                   scalar_t* __restrict__ grad_ptcloud,
+                                                   scalar_t* __restrict__ grad_cubic_features)
+{
+    int batch_index = blockIdx.x;
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    int cub_scale = scale * scale * scale;
+
+    grad_point_features += batch_index * n_pts * n_vertices * n_cubic_channels;
+    grid_pt_indexes += batch_index * n_pts * n_vertices;
+    grad_ptcloud += batch_index * n_pts * 3;
+    grad_cubic_features += batch_index * n_cubic_channels * cub_scale;
+
+    for (int i = index; i < n_pts; i += stride)
+    {
+        for (int j = 0; j < n_vertices; ++j)
+        {
+            int vertex_idx = grid_pt_indexes[i * n_vertices + j];
+            if (vertex_idx == -1)
+            {
+                continue;
+            }
+            for (int k = 0; k < n_cubic_channels; ++k)
+            {
+                int grad_idx = i * n_vertices * n_cubic_channels + j * n_cubic_channels + k;
+                scalar_t grad_val = grad_point_features[grad_idx];
+                // Bug fix: the gradients of the ceil and floor functions are zero,
+                // so no gradient is propagated to the point cloud coordinates.
+                // Ref: https://github.com/tensorflow/tensorflow/issues/897
+                // atomicAdd(&(grad_ptcloud[i * 3 + 0]), grad_val);
+                // atomicAdd(&(grad_ptcloud[i * 3 + 1]), grad_val);
+                // atomicAdd(&(grad_ptcloud[i * 3 + 2]), grad_val);
+                atomicAdd(&(grad_cubic_features[k * cub_scale + vertex_idx]), grad_val);
+            }
+        }
+    }
+}
+
+std::vector<torch::Tensor>
+cubic_feature_sampling_grad_kernel_wrapper(int scale, int neighborhood_size,
+                                           torch::Tensor grad_point_features,
+                                           torch::Tensor grid_pt_indexes, cudaStream_t stream)
+{
+    int batch_size = grad_point_features.size(0);
+    int n_cubic_channels = grad_point_features.size(3);
+    int n_pts = grid_pt_indexes.size(1);
+    int n_vertices = std::pow(neighborhood_size * 2, 3);
+
+    torch::Tensor grad_ptcloud =
+        torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grad_point_features.scalar_type()));
+    torch::Tensor grad_cubic_features =
+        torch::zeros({batch_size, n_cubic_channels, scale, scale, scale},
+                     torch::CUDA(grad_point_features.scalar_type()));
+
+    AT_DISPATCH_FLOATING_TYPES(
+        grad_point_features.scalar_type(), "cubic_feature_sampling_grad_cuda", ([&] {
+            cubic_feature_sampling_grad_kernel<<<batch_size, get_n_threads(n_pts), 0, stream>>>(
+                scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels,
+                grad_point_features.data_ptr<scalar_t>(), grid_pt_indexes.data_ptr<int>(),
+                grad_ptcloud.data_ptr<scalar_t>(), grad_cubic_features.data_ptr<scalar_t>());
+        }));
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+    {
+        printf("Error in cubic_feature_sampling_grad_kernel_wrapper: %s\n",
+               cudaGetErrorString(err));
+    }
+    return {grad_ptcloud, grad_cubic_features};
+}
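As the commented-out atomicAdd calls above indicate, the backward pass only accumulates gradients into grad_cubic_features; grad_ptcloud is allocated and returned but stays zero, since floor and ceil have zero gradient almost everywhere. A small sanity-check sketch under the same import assumption as above:

import torch
import torch_points_kernels.points_cuda as tpcuda  # assumed import path

B, N, C, scale, neighborhood_size = 2, 1024, 32, 16, 1
n_vertices = (2 * neighborhood_size) ** 3

grad_point_features = torch.rand(B, N, n_vertices, C, device="cuda")
grid_pt_indexes = torch.randint(0, scale ** 3, (B, N, n_vertices),
                                dtype=torch.int32, device="cuda")

grad_ptcloud, grad_cubic_features = tpcuda.cubic_feature_sampling_grad(
    scale, neighborhood_size, grad_point_features, grid_pt_indexes
)
assert grad_ptcloud.shape == (B, N, 3) and not grad_ptcloud.any()  # no gradient w.r.t. ptcloud
assert grad_cubic_features.shape == (B, C, scale, scale, scale)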

setup.py

Lines changed: 11 additions & 3 deletions
@@ -37,7 +37,10 @@ def get_ext_modules():
             name="torch_points_kernels.points_cuda",
             sources=ext_sources,
             include_dirs=["{}/include".format(ext_src_root)],
-            extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
+            extra_compile_args={
+                "cxx": extra_compile_args,
+                "nvcc": extra_compile_args,
+            },
         )
     )
 
@@ -49,7 +52,9 @@ def get_ext_modules():
             name="torch_points_kernels.points_cpu",
             sources=cpu_ext_sources,
             include_dirs=["{}/include".format(cpu_ext_src_root)],
-            extra_compile_args={"cxx": extra_compile_args,},
+            extra_compile_args={
+                "cxx": extra_compile_args,
+            },
         )
     )
     return ext_modules
@@ -81,5 +86,8 @@ def get_cmdclass():
     cmdclass=get_cmdclass(),
     long_description=long_description,
     long_description_content_type="text/markdown",
-    classifiers=["Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+    ],
 )

test/speed_radius.py

Lines changed: 20 additions & 2 deletions
@@ -23,8 +23,26 @@ def test_speed(self):
         R = 1
         samples = 50
 
-        idx, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
-        idx1, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
+        idx, dist = ball_query(
+            R,
+            samples,
+            a,
+            b,
+            mode="PARTIAL_DENSE",
+            batch_x=batch_a,
+            batch_y=batch_b,
+            sort=True,
+        )
+        idx1, dist = ball_query(
+            R,
+            samples,
+            a,
+            b,
+            mode="PARTIAL_DENSE",
+            batch_x=batch_a,
+            batch_y=batch_b,
+            sort=True,
+        )
         print(time.time() - start)
         torch.testing.assert_allclose(idx1, idx)
