44
55#include < vector>
66
7- __global__ void chamfer_dist_kernel (int batch_size, int n, const float * xyz1, int m,
8- const float * xyz2, float * dist, int * indexes)
7+ template <typename scalar_t >
8+ __global__ void chamfer_dist_kernel (int batch_size, int n, const scalar_t * __restrict__ xyz1, int m,
9+ const scalar_t * __restrict__ xyz2, scalar_t * __restrict__ dist,
10+ int * indexes)
911{
1012 const int batch = 512 ;
11- __shared__ float buf[batch * 3 ];
13+ __shared__ scalar_t buf[batch * 3 ];
1214 for (int i = blockIdx .x ; i < batch_size; i += gridDim .x )
1315 {
1416 for (int k2 = 0 ; k2 < m; k2 += batch)
@@ -21,21 +23,21 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
2123 __syncthreads ();
2224 for (int j = threadIdx .x + blockIdx .y * blockDim .x ; j < n; j += blockDim .x * gridDim .y )
2325 {
24- float x1 = xyz1[(i * n + j) * 3 + 0 ];
25- float y1 = xyz1[(i * n + j) * 3 + 1 ];
26- float z1 = xyz1[(i * n + j) * 3 + 2 ];
27- float best_dist = 0 ;
26+ scalar_t x1 = xyz1[(i * n + j) * 3 + 0 ];
27+ scalar_t y1 = xyz1[(i * n + j) * 3 + 1 ];
28+ scalar_t z1 = xyz1[(i * n + j) * 3 + 2 ];
29+ scalar_t best_dist = 0 ;
2830 int best_dist_index = 0 ;
2931 int end_ka = end_k - (end_k & 3 );
3032 if (end_ka == batch)
3133 {
3234 for (int k = 0 ; k < batch; k += 4 )
3335 {
3436 {
35- float x2 = buf[k * 3 + 0 ] - x1;
36- float y2 = buf[k * 3 + 1 ] - y1;
37- float z2 = buf[k * 3 + 2 ] - z1;
38- float dist = x2 * x2 + y2 * y2 + z2 * z2;
37+ scalar_t x2 = buf[k * 3 + 0 ] - x1;
38+ scalar_t y2 = buf[k * 3 + 1 ] - y1;
39+ scalar_t z2 = buf[k * 3 + 2 ] - z1;
40+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
3941
4042 if (k == 0 || dist < best_dist)
4143 {
@@ -44,32 +46,32 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
4446 }
4547 }
4648 {
47- float x2 = buf[k * 3 + 3 ] - x1;
48- float y2 = buf[k * 3 + 4 ] - y1;
49- float z2 = buf[k * 3 + 5 ] - z1;
50- float dist = x2 * x2 + y2 * y2 + z2 * z2;
49+ scalar_t x2 = buf[k * 3 + 3 ] - x1;
50+ scalar_t y2 = buf[k * 3 + 4 ] - y1;
51+ scalar_t z2 = buf[k * 3 + 5 ] - z1;
52+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
5153 if (dist < best_dist)
5254 {
5355 best_dist = dist;
5456 best_dist_index = k + k2 + 1 ;
5557 }
5658 }
5759 {
58- float x2 = buf[k * 3 + 6 ] - x1;
59- float y2 = buf[k * 3 + 7 ] - y1;
60- float z2 = buf[k * 3 + 8 ] - z1;
61- float dist = x2 * x2 + y2 * y2 + z2 * z2;
60+ scalar_t x2 = buf[k * 3 + 6 ] - x1;
61+ scalar_t y2 = buf[k * 3 + 7 ] - y1;
62+ scalar_t z2 = buf[k * 3 + 8 ] - z1;
63+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
6264 if (dist < best_dist)
6365 {
6466 best_dist = dist;
6567 best_dist_index = k + k2 + 2 ;
6668 }
6769 }
6870 {
69- float x2 = buf[k * 3 + 9 ] - x1;
70- float y2 = buf[k * 3 + 10 ] - y1;
71- float z2 = buf[k * 3 + 11 ] - z1;
72- float dist = x2 * x2 + y2 * y2 + z2 * z2;
71+ scalar_t x2 = buf[k * 3 + 9 ] - x1;
72+ scalar_t y2 = buf[k * 3 + 10 ] - y1;
73+ scalar_t z2 = buf[k * 3 + 11 ] - z1;
74+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
7375 if (dist < best_dist)
7476 {
7577 best_dist = dist;
@@ -83,43 +85,43 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
8385 for (int k = 0 ; k < end_ka; k += 4 )
8486 {
8587 {
86- float x2 = buf[k * 3 + 0 ] - x1;
87- float y2 = buf[k * 3 + 1 ] - y1;
88- float z2 = buf[k * 3 + 2 ] - z1;
89- float dist = x2 * x2 + y2 * y2 + z2 * z2;
88+ scalar_t x2 = buf[k * 3 + 0 ] - x1;
89+ scalar_t y2 = buf[k * 3 + 1 ] - y1;
90+ scalar_t z2 = buf[k * 3 + 2 ] - z1;
91+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
9092 if (k == 0 || dist < best_dist)
9193 {
9294 best_dist = dist;
9395 best_dist_index = k + k2;
9496 }
9597 }
9698 {
97- float x2 = buf[k * 3 + 3 ] - x1;
98- float y2 = buf[k * 3 + 4 ] - y1;
99- float z2 = buf[k * 3 + 5 ] - z1;
100- float dist = x2 * x2 + y2 * y2 + z2 * z2;
99+ scalar_t x2 = buf[k * 3 + 3 ] - x1;
100+ scalar_t y2 = buf[k * 3 + 4 ] - y1;
101+ scalar_t z2 = buf[k * 3 + 5 ] - z1;
102+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
101103 if (dist < best_dist)
102104 {
103105 best_dist = dist;
104106 best_dist_index = k + k2 + 1 ;
105107 }
106108 }
107109 {
108- float x2 = buf[k * 3 + 6 ] - x1;
109- float y2 = buf[k * 3 + 7 ] - y1;
110- float z2 = buf[k * 3 + 8 ] - z1;
111- float dist = x2 * x2 + y2 * y2 + z2 * z2;
110+ scalar_t x2 = buf[k * 3 + 6 ] - x1;
111+ scalar_t y2 = buf[k * 3 + 7 ] - y1;
112+ scalar_t z2 = buf[k * 3 + 8 ] - z1;
113+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
112114 if (dist < best_dist)
113115 {
114116 best_dist = dist;
115117 best_dist_index = k + k2 + 2 ;
116118 }
117119 }
118120 {
119- float x2 = buf[k * 3 + 9 ] - x1;
120- float y2 = buf[k * 3 + 10 ] - y1;
121- float z2 = buf[k * 3 + 11 ] - z1;
122- float dist = x2 * x2 + y2 * y2 + z2 * z2;
121+ scalar_t x2 = buf[k * 3 + 9 ] - x1;
122+ scalar_t y2 = buf[k * 3 + 10 ] - y1;
123+ scalar_t z2 = buf[k * 3 + 11 ] - z1;
124+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
123125 if (dist < best_dist)
124126 {
125127 best_dist = dist;
@@ -130,10 +132,10 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
130132 }
131133 for (int k = end_ka; k < end_k; k++)
132134 {
133- float x2 = buf[k * 3 + 0 ] - x1;
134- float y2 = buf[k * 3 + 1 ] - y1;
135- float z2 = buf[k * 3 + 2 ] - z1;
136- float dist = x2 * x2 + y2 * y2 + z2 * z2;
135+ scalar_t x2 = buf[k * 3 + 0 ] - x1;
136+ scalar_t y2 = buf[k * 3 + 1 ] - y1;
137+ scalar_t z2 = buf[k * 3 + 2 ] - z1;
138+ scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
137139 if (k == 0 || dist < best_dist)
138140 {
139141 best_dist = dist;
@@ -161,12 +163,16 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
161163 torch::Tensor idx1 = torch::zeros ({batch_size, n}, torch::CUDA (torch::kInt ));
162164 torch::Tensor idx2 = torch::zeros ({batch_size, m}, torch::CUDA (torch::kInt ));
163165
164- chamfer_dist_kernel<<<dim3 (32 , 16 , 1 ), 512 >>> (batch_size, n, xyz1.data_ptr <float >(), m,
165- xyz2.data_ptr <float >(), dist1.data_ptr <float >(),
166- idx1.data_ptr <int >());
167- chamfer_dist_kernel<<<dim3 (32 , 16 , 1 ), 512 >>> (batch_size, m, xyz2.data_ptr <float >(), n,
168- xyz1.data_ptr <float >(), dist2.data_ptr <float >(),
169- idx2.data_ptr <int >());
166+ AT_DISPATCH_FLOATING_TYPES (
167+ xyz1.scalar_type (), " chamfer_dist_cuda" , ([&] {
168+ chamfer_dist_kernel<scalar_t ><<<dim3 (32 , 16 , 1 ), 512 >>> (
169+ batch_size, n, xyz1.data_ptr <scalar_t >(), m, xyz2.data_ptr <scalar_t >(),
170+ dist1.data_ptr <scalar_t >(), idx1.data_ptr <int >());
171+
172+ chamfer_dist_kernel<scalar_t ><<<dim3 (32 , 16 , 1 ), 512 >>> (
173+ batch_size, m, xyz2.data_ptr <scalar_t >(), n, xyz1.data_ptr <scalar_t >(),
174+ dist2.data_ptr <scalar_t >(), idx2.data_ptr <int >());
175+ }));
170176
171177 cudaError_t err = cudaGetLastError ();
172178 if (err != cudaSuccess)
@@ -176,22 +182,25 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
176182 return {dist1, dist2, idx1, idx2};
177183}
178184
179- __global__ void chamfer_dist_grad_kernel (int b, int n, const float * xyz1, int m, const float * xyz2,
180- const float * grad_dist1, const int * idx1, float * grad_xyz1,
181- float * grad_xyz2)
185+ template <typename scalar_t >
186+ __global__ void chamfer_dist_grad_kernel (int b, int n, const scalar_t * __restrict__ xyz1, int m,
187+ const scalar_t * __restrict__ xyz2,
188+ const scalar_t * __restrict__ grad_dist1, const int * idx1,
189+ scalar_t * __restrict__ grad_xyz1,
190+ scalar_t * __restrict__ grad_xyz2)
182191{
183192 for (int i = blockIdx .x ; i < b; i += gridDim .x )
184193 {
185194 for (int j = threadIdx .x + blockIdx .y * blockDim .x ; j < n; j += blockDim .x * gridDim .y )
186195 {
187- float x1 = xyz1[(i * n + j) * 3 + 0 ];
188- float y1 = xyz1[(i * n + j) * 3 + 1 ];
189- float z1 = xyz1[(i * n + j) * 3 + 2 ];
196+ scalar_t x1 = xyz1[(i * n + j) * 3 + 0 ];
197+ scalar_t y1 = xyz1[(i * n + j) * 3 + 1 ];
198+ scalar_t z1 = xyz1[(i * n + j) * 3 + 2 ];
190199 int j2 = idx1[i * n + j];
191- float x2 = xyz2[(i * m + j2) * 3 + 0 ];
192- float y2 = xyz2[(i * m + j2) * 3 + 1 ];
193- float z2 = xyz2[(i * m + j2) * 3 + 2 ];
194- float g = grad_dist1[i * n + j] * 2 ;
200+ scalar_t x2 = xyz2[(i * m + j2) * 3 + 0 ];
201+ scalar_t y2 = xyz2[(i * m + j2) * 3 + 1 ];
202+ scalar_t z2 = xyz2[(i * m + j2) * 3 + 2 ];
203+ scalar_t g = grad_dist1[i * n + j] * 2 ;
195204 atomicAdd (&(grad_xyz1[(i * n + j) * 3 + 0 ]), g * (x1 - x2));
196205 atomicAdd (&(grad_xyz1[(i * n + j) * 3 + 1 ]), g * (y1 - y2));
197206 atomicAdd (&(grad_xyz1[(i * n + j) * 3 + 2 ]), g * (z1 - z2));
@@ -213,14 +222,18 @@ std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1,
213222 torch::Tensor grad_xyz1 = torch::zeros_like (xyz1, torch::CUDA (torch::kFloat ));
214223 torch::Tensor grad_xyz2 = torch::zeros_like (xyz2, torch::CUDA (torch::kFloat ));
215224
216- chamfer_dist_grad_kernel<<<dim3 (1 , 16 , 1 ), 256 >>> (
217- batch_size, n, xyz1.data_ptr <float >(), m, xyz2.data_ptr <float >(),
218- grad_dist1.data_ptr <float >(), idx1.data_ptr <int >(), grad_xyz1.data_ptr <float >(),
219- grad_xyz2.data_ptr <float >());
220- chamfer_dist_grad_kernel<<<dim3 (1 , 16 , 1 ), 256 >>> (
221- batch_size, m, xyz2.data_ptr <float >(), n, xyz1.data_ptr <float >(),
222- grad_dist2.data_ptr <float >(), idx2.data_ptr <int >(), grad_xyz2.data_ptr <float >(),
223- grad_xyz1.data_ptr <float >());
225+ AT_DISPATCH_FLOATING_TYPES (
226+ xyz1.scalar_type (), " chamfer_dist_grad_cuda" , ([&] {
227+ chamfer_dist_grad_kernel<scalar_t ><<<dim3 (1 , 16 , 1 ), 256 >>> (
228+ batch_size, n, xyz1.data_ptr <scalar_t >(), m, xyz2.data_ptr <scalar_t >(),
229+ grad_dist1.data_ptr <scalar_t >(), idx1.data_ptr <int >(),
230+ grad_xyz1.data_ptr <scalar_t >(), grad_xyz2.data_ptr <scalar_t >());
231+
232+ chamfer_dist_grad_kernel<scalar_t ><<<dim3 (1 , 16 , 1 ), 256 >>> (
233+ batch_size, m, xyz2.data_ptr <scalar_t >(), n, xyz1.data_ptr <scalar_t >(),
234+ grad_dist2.data_ptr <scalar_t >(), idx2.data_ptr <int >(),
235+ grad_xyz2.data_ptr <scalar_t >(), grad_xyz1.data_ptr <scalar_t >());
236+ }));
224237
225238 cudaError_t err = cudaGetLastError ();
226239 if (err != cudaSuccess)
0 commit comments