Skip to content

Commit c50bd18

Browse files
Merge pull request #57 from hzxie/gridding
Add the implement of the Gridding layer (arXiv 2006.03761)
2 parents 48e3a16 + e2c315b commit c50bd18

File tree

11 files changed

+486
-23
lines changed

11 files changed

+486
-23
lines changed

cuda/include/cuda_utils.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,22 @@ inline dim3 opt_block_config(int x, int y)
3232
// from https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
3333
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
3434
#else
35-
__device__ double atomicAdd(double* address, double val)
36-
{
37-
unsigned long long int* address_as_ull =
38-
(unsigned long long int*)address;
39-
unsigned long long int old = *address_as_ull, assumed;
35+
__device__ double atomicAdd(double* address, double val)
36+
{
37+
unsigned long long int* address_as_ull = (unsigned long long int*)address;
38+
unsigned long long int old = *address_as_ull, assumed;
4039

41-
do {
42-
assumed = old;
43-
old = atomicCAS(address_as_ull, assumed,
44-
__double_as_longlong(val +
45-
__longlong_as_double(assumed)));
40+
do
41+
{
42+
assumed = old;
43+
old = atomicCAS(address_as_ull, assumed,
44+
__double_as_longlong(val + __longlong_as_double(assumed)));
4645

4746
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
48-
} while (assumed != old);
47+
} while (assumed != old);
4948

50-
return __longlong_as_double(old);
51-
}
49+
return __longlong_as_double(old);
50+
}
5251
#endif
5352

5453
#define CUDA_CHECK_ERRORS() \

cuda/include/gridding.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include <vector>
2+
3+
#include <ATen/cuda/CUDAContext.h>
4+
#include <torch/extension.h>
5+
6+
std::vector<torch::Tensor> gridding_kernel_warpper(float min_x, float max_x, float min_y,
7+
float max_y, float min_z, float max_z,
8+
torch::Tensor ptcloud, cudaStream_t stream);
9+
10+
torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights,
11+
torch::Tensor grid_pt_indexes, torch::Tensor grad_grid,
12+
cudaStream_t stream);
13+
14+
std::vector<torch::Tensor> gridding(float min_x, float max_x, float min_y, float max_y, float min_z,
15+
float max_z, torch::Tensor ptcloud);
16+
17+
torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes,
18+
torch::Tensor grad_grid);

cuda/src/bindings.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "ball_query.h"
22
#include "chamfer_dist.h"
33
#include "cubic_feature_sampling.h"
4+
#include "gridding.h"
45
#include "interpolate.h"
56
#include "metrics.h"
67
#include "sampling.h"
@@ -23,4 +24,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2324

2425
m.def("cubic_feature_sampling", &cubic_feature_sampling);
2526
m.def("cubic_feature_sampling_grad", &cubic_feature_sampling_grad);
27+
28+
m.def("gridding", &gridding);
29+
m.def("gridding_grad", &gridding_grad);
2630
}

cuda/src/chamfer_dist_gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#include <cuda_runtime.h>
33
#include <torch/extension.h>
44

5-
#include <vector>
65
#include "cuda_utils.h"
6+
#include <vector>
77

88
template <typename scalar_t>
99
__global__ void chamfer_dist_kernel(int batch_size, int n, const scalar_t* __restrict__ xyz1, int m,

cuda/src/gridding.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "gridding.h"
2+
#include "utils.h"
3+
4+
std::vector<torch::Tensor> gridding(float min_x, float max_x, float min_y, float max_y, float min_z,
5+
float max_z, torch::Tensor ptcloud)
6+
{
7+
CHECK_CUDA(ptcloud);
8+
CHECK_CONTIGUOUS(ptcloud);
9+
10+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
11+
return gridding_kernel_warpper(min_x, max_x, min_y, max_y, min_z, max_z, ptcloud, stream);
12+
}
13+
14+
torch::Tensor gridding_grad(torch::Tensor grid_pt_weights, torch::Tensor grid_pt_indexes,
15+
torch::Tensor grad_grid)
16+
{
17+
CHECK_CUDA(grid_pt_weights);
18+
CHECK_CONTIGUOUS(grid_pt_weights);
19+
CHECK_CUDA(grid_pt_indexes);
20+
CHECK_CONTIGUOUS(grid_pt_indexes);
21+
CHECK_CUDA(grad_grid);
22+
CHECK_CONTIGUOUS(grad_grid);
23+
24+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
25+
return gridding_grad_kernel_warpper(grid_pt_weights, grid_pt_indexes, grad_grid, stream);
26+
}

0 commit comments

Comments
 (0)