Skip to content

Commit a50d3f2

Browse files
committed
Add pybinding for Chamfer Distance.
1 parent 652c179 commit a50d3f2

File tree

5 files changed

+91
-69
lines changed

5 files changed

+91
-69
lines changed

cuda/include/chamfer_dist.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
#include <torch/extension.h>
22
#include <vector>
33

4+
std::vector<torch::Tensor> chamfer_dist(torch::Tensor xyz1, torch::Tensor xyz2);
5+
6+
std::vector<torch::Tensor> chamfer_dist_grad(torch::Tensor xyz1, torch::Tensor xyz2,
7+
torch::Tensor idx1, torch::Tensor idx2,
8+
torch::Tensor grad_dist1, torch::Tensor grad_dist2);
9+
410
std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2);
511

612
std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2,
713
torch::Tensor idx1, torch::Tensor idx2,
814
torch::Tensor grad_dist1,
9-
torch::Tensor grad_dist2);
15+
torch::Tensor grad_dist2);

cuda/src/bindings.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "ball_query.h"
2+
#include "chamfer_dist.h"
23
#include "interpolate.h"
34
#include "metrics.h"
45
#include "sampling.h"
@@ -15,4 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
1516
m.def("ball_query_partial_dense", &ball_query_partial_dense);
1617

1718
m.def("instance_iou_cuda", &instance_iou_cuda);
19+
20+
m.def("chamfer_dist", &chamfer_dist);
21+
m.def("chamfer_dist_grad", &chamfer_dist_grad);
1822
}

cuda/src/chamfer_dist.cu

Lines changed: 80 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
#include <vector>
66

7-
__global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, int m,
8-
const float* xyz2, float* dist, int* indexes)
7+
template <typename scalar_t>
8+
__global__ void chamfer_dist_kernel(int batch_size, int n, const scalar_t* __restrict__ xyz1, int m,
9+
const scalar_t* __restrict__ xyz2, scalar_t* __restrict__ dist,
10+
int* indexes)
911
{
1012
const int batch = 512;
11-
__shared__ float buf[batch * 3];
13+
__shared__ scalar_t buf[batch * 3];
1214
for (int i = blockIdx.x; i < batch_size; i += gridDim.x)
1315
{
1416
for (int k2 = 0; k2 < m; k2 += batch)
@@ -21,21 +23,21 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
2123
__syncthreads();
2224
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y)
2325
{
24-
float x1 = xyz1[(i * n + j) * 3 + 0];
25-
float y1 = xyz1[(i * n + j) * 3 + 1];
26-
float z1 = xyz1[(i * n + j) * 3 + 2];
27-
float best_dist = 0;
26+
scalar_t x1 = xyz1[(i * n + j) * 3 + 0];
27+
scalar_t y1 = xyz1[(i * n + j) * 3 + 1];
28+
scalar_t z1 = xyz1[(i * n + j) * 3 + 2];
29+
scalar_t best_dist = 0;
2830
int best_dist_index = 0;
2931
int end_ka = end_k - (end_k & 3);
3032
if (end_ka == batch)
3133
{
3234
for (int k = 0; k < batch; k += 4)
3335
{
3436
{
35-
float x2 = buf[k * 3 + 0] - x1;
36-
float y2 = buf[k * 3 + 1] - y1;
37-
float z2 = buf[k * 3 + 2] - z1;
38-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
37+
scalar_t x2 = buf[k * 3 + 0] - x1;
38+
scalar_t y2 = buf[k * 3 + 1] - y1;
39+
scalar_t z2 = buf[k * 3 + 2] - z1;
40+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
3941

4042
if (k == 0 || dist < best_dist)
4143
{
@@ -44,32 +46,32 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
4446
}
4547
}
4648
{
47-
float x2 = buf[k * 3 + 3] - x1;
48-
float y2 = buf[k * 3 + 4] - y1;
49-
float z2 = buf[k * 3 + 5] - z1;
50-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
49+
scalar_t x2 = buf[k * 3 + 3] - x1;
50+
scalar_t y2 = buf[k * 3 + 4] - y1;
51+
scalar_t z2 = buf[k * 3 + 5] - z1;
52+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
5153
if (dist < best_dist)
5254
{
5355
best_dist = dist;
5456
best_dist_index = k + k2 + 1;
5557
}
5658
}
5759
{
58-
float x2 = buf[k * 3 + 6] - x1;
59-
float y2 = buf[k * 3 + 7] - y1;
60-
float z2 = buf[k * 3 + 8] - z1;
61-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
60+
scalar_t x2 = buf[k * 3 + 6] - x1;
61+
scalar_t y2 = buf[k * 3 + 7] - y1;
62+
scalar_t z2 = buf[k * 3 + 8] - z1;
63+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
6264
if (dist < best_dist)
6365
{
6466
best_dist = dist;
6567
best_dist_index = k + k2 + 2;
6668
}
6769
}
6870
{
69-
float x2 = buf[k * 3 + 9] - x1;
70-
float y2 = buf[k * 3 + 10] - y1;
71-
float z2 = buf[k * 3 + 11] - z1;
72-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
71+
scalar_t x2 = buf[k * 3 + 9] - x1;
72+
scalar_t y2 = buf[k * 3 + 10] - y1;
73+
scalar_t z2 = buf[k * 3 + 11] - z1;
74+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
7375
if (dist < best_dist)
7476
{
7577
best_dist = dist;
@@ -83,43 +85,43 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
8385
for (int k = 0; k < end_ka; k += 4)
8486
{
8587
{
86-
float x2 = buf[k * 3 + 0] - x1;
87-
float y2 = buf[k * 3 + 1] - y1;
88-
float z2 = buf[k * 3 + 2] - z1;
89-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
88+
scalar_t x2 = buf[k * 3 + 0] - x1;
89+
scalar_t y2 = buf[k * 3 + 1] - y1;
90+
scalar_t z2 = buf[k * 3 + 2] - z1;
91+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
9092
if (k == 0 || dist < best_dist)
9193
{
9294
best_dist = dist;
9395
best_dist_index = k + k2;
9496
}
9597
}
9698
{
97-
float x2 = buf[k * 3 + 3] - x1;
98-
float y2 = buf[k * 3 + 4] - y1;
99-
float z2 = buf[k * 3 + 5] - z1;
100-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
99+
scalar_t x2 = buf[k * 3 + 3] - x1;
100+
scalar_t y2 = buf[k * 3 + 4] - y1;
101+
scalar_t z2 = buf[k * 3 + 5] - z1;
102+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
101103
if (dist < best_dist)
102104
{
103105
best_dist = dist;
104106
best_dist_index = k + k2 + 1;
105107
}
106108
}
107109
{
108-
float x2 = buf[k * 3 + 6] - x1;
109-
float y2 = buf[k * 3 + 7] - y1;
110-
float z2 = buf[k * 3 + 8] - z1;
111-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
110+
scalar_t x2 = buf[k * 3 + 6] - x1;
111+
scalar_t y2 = buf[k * 3 + 7] - y1;
112+
scalar_t z2 = buf[k * 3 + 8] - z1;
113+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
112114
if (dist < best_dist)
113115
{
114116
best_dist = dist;
115117
best_dist_index = k + k2 + 2;
116118
}
117119
}
118120
{
119-
float x2 = buf[k * 3 + 9] - x1;
120-
float y2 = buf[k * 3 + 10] - y1;
121-
float z2 = buf[k * 3 + 11] - z1;
122-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
121+
scalar_t x2 = buf[k * 3 + 9] - x1;
122+
scalar_t y2 = buf[k * 3 + 10] - y1;
123+
scalar_t z2 = buf[k * 3 + 11] - z1;
124+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
123125
if (dist < best_dist)
124126
{
125127
best_dist = dist;
@@ -130,10 +132,10 @@ __global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, in
130132
}
131133
for (int k = end_ka; k < end_k; k++)
132134
{
133-
float x2 = buf[k * 3 + 0] - x1;
134-
float y2 = buf[k * 3 + 1] - y1;
135-
float z2 = buf[k * 3 + 2] - z1;
136-
float dist = x2 * x2 + y2 * y2 + z2 * z2;
135+
scalar_t x2 = buf[k * 3 + 0] - x1;
136+
scalar_t y2 = buf[k * 3 + 1] - y1;
137+
scalar_t z2 = buf[k * 3 + 2] - z1;
138+
scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
137139
if (k == 0 || dist < best_dist)
138140
{
139141
best_dist = dist;
@@ -161,12 +163,16 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
161163
torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
162164
torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
163165

164-
chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, n, xyz1.data_ptr<float>(), m,
165-
xyz2.data_ptr<float>(), dist1.data_ptr<float>(),
166-
idx1.data_ptr<int>());
167-
chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, m, xyz2.data_ptr<float>(), n,
168-
xyz1.data_ptr<float>(), dist2.data_ptr<float>(),
169-
idx2.data_ptr<int>());
166+
AT_DISPATCH_FLOATING_TYPES(
167+
xyz1.scalar_type(), "chamfer_dist_cuda", ([&] {
168+
chamfer_dist_kernel<scalar_t><<<dim3(32, 16, 1), 512>>>(
169+
batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
170+
dist1.data_ptr<scalar_t>(), idx1.data_ptr<int>());
171+
172+
chamfer_dist_kernel<scalar_t><<<dim3(32, 16, 1), 512>>>(
173+
batch_size, m, xyz2.data_ptr<scalar_t>(), n, xyz1.data_ptr<scalar_t>(),
174+
dist2.data_ptr<scalar_t>(), idx2.data_ptr<int>());
175+
}));
170176

171177
cudaError_t err = cudaGetLastError();
172178
if (err != cudaSuccess)
@@ -176,22 +182,25 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
176182
return {dist1, dist2, idx1, idx2};
177183
}
178184

179-
__global__ void chamfer_dist_grad_kernel(int b, int n, const float* xyz1, int m, const float* xyz2,
180-
const float* grad_dist1, const int* idx1, float* grad_xyz1,
181-
float* grad_xyz2)
185+
template <typename scalar_t>
186+
__global__ void chamfer_dist_grad_kernel(int b, int n, const scalar_t* __restrict__ xyz1, int m,
187+
const scalar_t* __restrict__ xyz2,
188+
const scalar_t* __restrict__ grad_dist1, const int* idx1,
189+
scalar_t* __restrict__ grad_xyz1,
190+
scalar_t* __restrict__ grad_xyz2)
182191
{
183192
for (int i = blockIdx.x; i < b; i += gridDim.x)
184193
{
185194
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y)
186195
{
187-
float x1 = xyz1[(i * n + j) * 3 + 0];
188-
float y1 = xyz1[(i * n + j) * 3 + 1];
189-
float z1 = xyz1[(i * n + j) * 3 + 2];
196+
scalar_t x1 = xyz1[(i * n + j) * 3 + 0];
197+
scalar_t y1 = xyz1[(i * n + j) * 3 + 1];
198+
scalar_t z1 = xyz1[(i * n + j) * 3 + 2];
190199
int j2 = idx1[i * n + j];
191-
float x2 = xyz2[(i * m + j2) * 3 + 0];
192-
float y2 = xyz2[(i * m + j2) * 3 + 1];
193-
float z2 = xyz2[(i * m + j2) * 3 + 2];
194-
float g = grad_dist1[i * n + j] * 2;
200+
scalar_t x2 = xyz2[(i * m + j2) * 3 + 0];
201+
scalar_t y2 = xyz2[(i * m + j2) * 3 + 1];
202+
scalar_t z2 = xyz2[(i * m + j2) * 3 + 2];
203+
scalar_t g = grad_dist1[i * n + j] * 2;
195204
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
196205
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
197206
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
@@ -213,14 +222,18 @@ std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1,
213222
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));
214223
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));
215224

216-
chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
217-
batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
218-
grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(), grad_xyz1.data_ptr<float>(),
219-
grad_xyz2.data_ptr<float>());
220-
chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
221-
batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
222-
grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(), grad_xyz2.data_ptr<float>(),
223-
grad_xyz1.data_ptr<float>());
225+
AT_DISPATCH_FLOATING_TYPES(
226+
xyz1.scalar_type(), "chamfer_dist_grad_cuda", ([&] {
227+
chamfer_dist_grad_kernel<scalar_t><<<dim3(1, 16, 1), 256>>>(
228+
batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
229+
grad_dist1.data_ptr<scalar_t>(), idx1.data_ptr<int>(),
230+
grad_xyz1.data_ptr<scalar_t>(), grad_xyz2.data_ptr<scalar_t>());
231+
232+
chamfer_dist_grad_kernel<scalar_t><<<dim3(1, 16, 1), 256>>>(
233+
batch_size, m, xyz2.data_ptr<scalar_t>(), n, xyz1.data_ptr<scalar_t>(),
234+
grad_dist2.data_ptr<scalar_t>(), idx2.data_ptr<int>(),
235+
grad_xyz2.data_ptr<scalar_t>(), grad_xyz1.data_ptr<scalar_t>());
236+
}));
224237

225238
cudaError_t err = cudaGetLastError();
226239
if (err != cudaSuccess)
File renamed without changes.

torch_points_kernels/torchpoints.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ class ChamferFunction(Function):
241241
@staticmethod
242242
def forward(ctx, xyz1, xyz2):
243243
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
244-
print(dir(tpcuda))
245244
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
246245

247246
return dist1, dist2

0 commit comments

Comments
 (0)