@@ -19,15 +19,16 @@ __device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y
1919 return offset_x * len_y * len_z + offset_y * len_z + offset_z;
2020}
2121
22- __device__ float compute_weight (float x, float x0)
22+ template <typename scalar_t >
23+ __device__ scalar_t compute_weight (scalar_t x, scalar_t x0)
2324{
2425 return 1 - abs (x - x0);
2526}
2627
2728template <typename scalar_t >
2829__global__ void
29- gridding_kernel (int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z,
30- int len_y, int len_z, const scalar_t * __restrict__ ptcloud,
30+ gridding_kernel (int n_grid_vertices, int n_pts, float min_x, float min_y, float min_z, int len_y,
31+ int len_z, const scalar_t * __restrict__ ptcloud,
3132 scalar_t * __restrict__ grid_weights, scalar_t * __restrict__ grid_pt_weights,
3233 int * __restrict__ grid_pt_indexes)
3334{
@@ -72,51 +73,51 @@ gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float
7273 // Compute weights and corresponding positions, a loop for 8 points
7374 // LLL -> Lower X, Lower Y, Lower Z
7475 grid_pt_indexes[j * 8 + 0 ] = compute_index (lx_offset, ly_offset, lz_offset, len_y, len_z);
75- grid_pt_weights[j * 24 + 0 ] = compute_weight (pt_x, lower_x);
76- grid_pt_weights[j * 24 + 1 ] = compute_weight (pt_y, lower_y);
77- grid_pt_weights[j * 24 + 2 ] = compute_weight (pt_z, lower_z);
76+ grid_pt_weights[j * 24 + 0 ] = compute_weight< scalar_t > (pt_x, lower_x);
77+ grid_pt_weights[j * 24 + 1 ] = compute_weight< scalar_t > (pt_y, lower_y);
78+ grid_pt_weights[j * 24 + 2 ] = compute_weight< scalar_t > (pt_z, lower_z);
7879
7980 // LLU -> Lower X, Lower Y, Upper Z
8081 grid_pt_indexes[j * 8 + 1 ] = compute_index (lx_offset, ly_offset, uz_offset, len_y, len_z);
81- grid_pt_weights[j * 24 + 3 ] = compute_weight (pt_x, lower_x);
82- grid_pt_weights[j * 24 + 4 ] = compute_weight (pt_y, lower_y);
83- grid_pt_weights[j * 24 + 5 ] = compute_weight (pt_z, upper_z);
82+ grid_pt_weights[j * 24 + 3 ] = compute_weight< scalar_t > (pt_x, lower_x);
83+ grid_pt_weights[j * 24 + 4 ] = compute_weight< scalar_t > (pt_y, lower_y);
84+ grid_pt_weights[j * 24 + 5 ] = compute_weight< scalar_t > (pt_z, upper_z);
8485
8586 // LUL -> Lower X, Upper Y, Lower Z
8687 grid_pt_indexes[j * 8 + 2 ] = compute_index (lx_offset, uy_offset, lz_offset, len_y, len_z);
87- grid_pt_weights[j * 24 + 6 ] = compute_weight (pt_x, lower_x);
88- grid_pt_weights[j * 24 + 7 ] = compute_weight (pt_y, upper_y);
89- grid_pt_weights[j * 24 + 8 ] = compute_weight (pt_z, lower_z);
88+ grid_pt_weights[j * 24 + 6 ] = compute_weight< scalar_t > (pt_x, lower_x);
89+ grid_pt_weights[j * 24 + 7 ] = compute_weight< scalar_t > (pt_y, upper_y);
90+ grid_pt_weights[j * 24 + 8 ] = compute_weight< scalar_t > (pt_z, lower_z);
9091
9192 // LUU -> Lower X, Upper Y, Upper Z
9293 grid_pt_indexes[j * 8 + 3 ] = compute_index (lx_offset, uy_offset, uz_offset, len_y, len_z);
93- grid_pt_weights[j * 24 + 9 ] = compute_weight (pt_x, lower_x);
94- grid_pt_weights[j * 24 + 10 ] = compute_weight (pt_y, upper_y);
95- grid_pt_weights[j * 24 + 11 ] = compute_weight (pt_z, upper_z);
94+ grid_pt_weights[j * 24 + 9 ] = compute_weight< scalar_t > (pt_x, lower_x);
95+ grid_pt_weights[j * 24 + 10 ] = compute_weight< scalar_t > (pt_y, upper_y);
96+ grid_pt_weights[j * 24 + 11 ] = compute_weight< scalar_t > (pt_z, upper_z);
9697
9798 // ULL -> Upper X, Lower Y, Lower Z
9899 grid_pt_indexes[j * 8 + 4 ] = compute_index (ux_offset, ly_offset, lz_offset, len_y, len_z);
99- grid_pt_weights[j * 24 + 12 ] = compute_weight (pt_x, upper_x);
100- grid_pt_weights[j * 24 + 13 ] = compute_weight (pt_y, lower_y);
101- grid_pt_weights[j * 24 + 14 ] = compute_weight (pt_z, lower_z);
100+ grid_pt_weights[j * 24 + 12 ] = compute_weight< scalar_t > (pt_x, upper_x);
101+ grid_pt_weights[j * 24 + 13 ] = compute_weight< scalar_t > (pt_y, lower_y);
102+ grid_pt_weights[j * 24 + 14 ] = compute_weight< scalar_t > (pt_z, lower_z);
102103
103104 // ULU -> Upper X, Lower Y, Upper Z
104105 grid_pt_indexes[j * 8 + 5 ] = compute_index (ux_offset, ly_offset, uz_offset, len_y, len_z);
105- grid_pt_weights[j * 24 + 15 ] = compute_weight (pt_x, upper_x);
106- grid_pt_weights[j * 24 + 16 ] = compute_weight (pt_y, lower_y);
107- grid_pt_weights[j * 24 + 17 ] = compute_weight (pt_z, upper_z);
106+ grid_pt_weights[j * 24 + 15 ] = compute_weight< scalar_t > (pt_x, upper_x);
107+ grid_pt_weights[j * 24 + 16 ] = compute_weight< scalar_t > (pt_y, lower_y);
108+ grid_pt_weights[j * 24 + 17 ] = compute_weight< scalar_t > (pt_z, upper_z);
108109
109110 // UUL -> Upper X, Upper Y, Lower Z
110111 grid_pt_indexes[j * 8 + 6 ] = compute_index (ux_offset, uy_offset, lz_offset, len_y, len_z);
111- grid_pt_weights[j * 24 + 18 ] = compute_weight (pt_x, upper_x);
112- grid_pt_weights[j * 24 + 19 ] = compute_weight (pt_y, upper_y);
113- grid_pt_weights[j * 24 + 20 ] = compute_weight (pt_z, lower_z);
112+ grid_pt_weights[j * 24 + 18 ] = compute_weight< scalar_t > (pt_x, upper_x);
113+ grid_pt_weights[j * 24 + 19 ] = compute_weight< scalar_t > (pt_y, upper_y);
114+ grid_pt_weights[j * 24 + 20 ] = compute_weight< scalar_t > (pt_z, lower_z);
114115
115116 // UUU -> Upper X, Upper Y, Upper Z
116117 grid_pt_indexes[j * 8 + 7 ] = compute_index (ux_offset, uy_offset, uz_offset, len_y, len_z);
117- grid_pt_weights[j * 24 + 21 ] = compute_weight (pt_x, upper_x);
118- grid_pt_weights[j * 24 + 22 ] = compute_weight (pt_y, upper_y);
119- grid_pt_weights[j * 24 + 23 ] = compute_weight (pt_z, upper_z);
118+ grid_pt_weights[j * 24 + 21 ] = compute_weight< scalar_t > (pt_x, upper_x);
119+ grid_pt_weights[j * 24 + 22 ] = compute_weight< scalar_t > (pt_y, upper_y);
120+ grid_pt_weights[j * 24 + 23 ] = compute_weight< scalar_t > (pt_z, upper_z);
120121 }
121122
122123 __syncthreads ();
@@ -179,9 +180,9 @@ std::vector<torch::Tensor> gridding_kernel_warpper(float min_x, float max_x, flo
179180 int n_grid_vertices = len_x * len_y * len_z;
180181
181182 torch::Tensor grid_weights =
182- torch::zeros ({batch_size, n_grid_vertices}, torch::CUDA (torch:: kFloat ));
183+ torch::zeros ({batch_size, n_grid_vertices}, torch::CUDA (ptcloud. scalar_type () ));
183184 torch::Tensor grid_pt_weights =
184- torch::zeros ({batch_size, n_pts, 8 , 3 }, torch::CUDA (torch:: kFloat ));
185+ torch::zeros ({batch_size, n_pts, 8 , 3 }, torch::CUDA (ptcloud. scalar_type () ));
185186 torch::Tensor grid_pt_indexes = torch::zeros ({batch_size, n_pts, 8 }, torch::CUDA (torch::kInt ));
186187
187188 AT_DISPATCH_FLOATING_TYPES (
@@ -310,7 +311,8 @@ torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights,
310311 int n_grid_vertices = grad_grid.size (1 );
311312 int n_pts = grid_pt_indexes.size (1 );
312313
313- torch::Tensor grad_ptcloud = torch::zeros ({batch_size, n_pts, 3 }, torch::CUDA (torch::kFloat ));
314+ torch::Tensor grad_ptcloud =
315+ torch::zeros ({batch_size, n_pts, 3 }, torch::CUDA (grid_pt_weights.scalar_type ()));
314316
315317 AT_DISPATCH_FLOATING_TYPES (
316318 grid_pt_weights.scalar_type (), " gridding_grad_cuda" , ([&] {
0 commit comments