
Commit e3d66bc

Merge pull request #46 from hzxie/chamfer-dist
Add Chamfer Distance
2 parents 83ac174 + a28cd6f commit e3d66bc

File tree: 8 files changed, +374 / -0 lines changed


README.md

Lines changed: 2 additions & 0 deletions
@@ -62,3 +62,5 @@ See [this useful chart](http://arnon.dk/matching-sm-architectures-arch-and-genco
 * [```Pointnet2_Tensorflow```](https://github.com/charlesq34/pointnet2) by [Charles R. Qi](https://github.com/charlesq34)

 * [```Pointnet2_PyTorch```](https://github.com/erikwijmans/Pointnet2_PyTorch) by [Erik Wijmans](https://github.com/erikwijmans)
+
+* [```GRNet```](https://github.com/hzxie/GRNet) by [Haozhe Xie](https://github.com/hzxie)

cuda/include/chamfer_dist.h

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> chamfer_dist(torch::Tensor xyz1, torch::Tensor xyz2);

std::vector<torch::Tensor> chamfer_dist_grad(torch::Tensor xyz1, torch::Tensor xyz2,
                                             torch::Tensor idx1, torch::Tensor idx2,
                                             torch::Tensor grad_dist1, torch::Tensor grad_dist2);

std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2);

std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2,
                                                            torch::Tensor idx1, torch::Tensor idx2,
                                                            torch::Tensor grad_dist1,
                                                            torch::Tensor grad_dist2);
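
For reference, the two exported functions return per-point nearest-neighbour results rather than a reduced scalar: for every point of one cloud, the squared distance to its closest point in the other cloud and the index of that point. A minimal pure-PyTorch sketch of the same computation (illustrative only, not part of this commit):

import torch

def chamfer_dist_reference(xyz1, xyz2):
    # xyz1: (B, n, 3), xyz2: (B, m, 3)
    diff = xyz1.unsqueeze(2) - xyz2.unsqueeze(1)   # (B, n, m, 3)
    sq_dist = (diff * diff).sum(-1)                # pairwise squared distances, (B, n, m)
    dist1, idx1 = sq_dist.min(dim=2)               # nearest point in xyz2 for each point of xyz1
    dist2, idx2 = sq_dist.min(dim=1)               # nearest point in xyz1 for each point of xyz2
    return dist1, dist2, idx1.int(), idx2.int()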

cuda/src/bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "ball_query.h"
+#include "chamfer_dist.h"
 #include "interpolate.h"
 #include "metrics.h"
 #include "sampling.h"
@@ -15,4 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("ball_query_partial_dense", &ball_query_partial_dense);

     m.def("instance_iou_cuda", &instance_iou_cuda);
+
+    m.def("chamfer_dist", &chamfer_dist);
+    m.def("chamfer_dist_grad", &chamfer_dist_grad);
 }
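
Once the extension is compiled, the two bindings can be called directly; a hedged usage sketch (the compiled module name, `tpcuda` here, is a placeholder; the actual name depends on how setup.py names the TORCH_EXTENSION_NAME module and is not shown in this diff):

import torch
import tpcuda  # hypothetical name for the compiled CUDA extension

xyz1 = torch.rand(2, 64, 3).cuda()
xyz2 = torch.rand(2, 128, 3).cuda()
# Forward pass: squared distance to the nearest neighbour in the other cloud,
# plus the index of that neighbour, in both directions.
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)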

cuda/src/chamfer_dist.cu

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include <vector>

template <typename scalar_t>
__global__ void chamfer_dist_kernel(int batch_size, int n, const scalar_t* __restrict__ xyz1, int m,
                                    const scalar_t* __restrict__ xyz2, scalar_t* __restrict__ dist,
                                    int* indexes)
{
    const int batch = 512;
    __shared__ scalar_t buf[batch * 3];
    for (int i = blockIdx.x; i < batch_size; i += gridDim.x)
    {
        for (int k2 = 0; k2 < m; k2 += batch)
        {
            int end_k = min(m, k2 + batch) - k2;
            for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x)
            {
                buf[j] = xyz2[(i * m + k2) * 3 + j];
            }
            __syncthreads();
            for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y)
            {
                scalar_t x1 = xyz1[(i * n + j) * 3 + 0];
                scalar_t y1 = xyz1[(i * n + j) * 3 + 1];
                scalar_t z1 = xyz1[(i * n + j) * 3 + 2];
                scalar_t best_dist = 0;
                int best_dist_index = 0;
                int end_ka = end_k - (end_k & 3);
                if (end_ka == batch)
                {
                    for (int k = 0; k < batch; k += 4)
                    {
                        {
                            scalar_t x2 = buf[k * 3 + 0] - x1;
                            scalar_t y2 = buf[k * 3 + 1] - y1;
                            scalar_t z2 = buf[k * 3 + 2] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;

                            if (k == 0 || dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 3] - x1;
                            scalar_t y2 = buf[k * 3 + 4] - y1;
                            scalar_t z2 = buf[k * 3 + 5] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 1;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 6] - x1;
                            scalar_t y2 = buf[k * 3 + 7] - y1;
                            scalar_t z2 = buf[k * 3 + 8] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 2;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 9] - x1;
                            scalar_t y2 = buf[k * 3 + 10] - y1;
                            scalar_t z2 = buf[k * 3 + 11] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 3;
                            }
                        }
                    }
                }
                else
                {
                    for (int k = 0; k < end_ka; k += 4)
                    {
                        {
                            scalar_t x2 = buf[k * 3 + 0] - x1;
                            scalar_t y2 = buf[k * 3 + 1] - y1;
                            scalar_t z2 = buf[k * 3 + 2] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (k == 0 || dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 3] - x1;
                            scalar_t y2 = buf[k * 3 + 4] - y1;
                            scalar_t z2 = buf[k * 3 + 5] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 1;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 6] - x1;
                            scalar_t y2 = buf[k * 3 + 7] - y1;
                            scalar_t z2 = buf[k * 3 + 8] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 2;
                            }
                        }
                        {
                            scalar_t x2 = buf[k * 3 + 9] - x1;
                            scalar_t y2 = buf[k * 3 + 10] - y1;
                            scalar_t z2 = buf[k * 3 + 11] - z1;
                            scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                            if (dist < best_dist)
                            {
                                best_dist = dist;
                                best_dist_index = k + k2 + 3;
                            }
                        }
                    }
                }
                for (int k = end_ka; k < end_k; k++)
                {
                    scalar_t x2 = buf[k * 3 + 0] - x1;
                    scalar_t y2 = buf[k * 3 + 1] - y1;
                    scalar_t z2 = buf[k * 3 + 2] - z1;
                    scalar_t dist = x2 * x2 + y2 * y2 + z2 * z2;
                    if (k == 0 || dist < best_dist)
                    {
                        best_dist = dist;
                        best_dist_index = k + k2;
                    }
                }
                if (k2 == 0 || dist[(i * n + j)] > best_dist)
                {
                    dist[(i * n + j)] = best_dist;
                    indexes[(i * n + j)] = best_dist_index;
                }
            }
            __syncthreads();
        }
    }
}

std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2)
{
    const int batch_size = xyz1.size(0);
    const int n = xyz1.size(1); // num_points point cloud A
    const int m = xyz2.size(1); // num_points point cloud B
    torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type()));
    torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type()));
    torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
    torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));

    AT_DISPATCH_FLOATING_TYPES(
        xyz1.scalar_type(), "chamfer_dist_cuda", ([&] {
            chamfer_dist_kernel<scalar_t><<<dim3(32, 16, 1), 512>>>(
                batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
                dist1.data_ptr<scalar_t>(), idx1.data_ptr<int>());

            chamfer_dist_kernel<scalar_t><<<dim3(32, 16, 1), 512>>>(
                batch_size, m, xyz2.data_ptr<scalar_t>(), n, xyz1.data_ptr<scalar_t>(),
                dist2.data_ptr<scalar_t>(), idx2.data_ptr<int>());
        }));

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("Error in chamfer_dist_kernel_wrapper: %s\n", cudaGetErrorString(err));
    }
    return {dist1, dist2, idx1, idx2};
}

template <typename scalar_t>
__global__ void chamfer_dist_grad_kernel(int b, int n, const scalar_t* __restrict__ xyz1, int m,
                                         const scalar_t* __restrict__ xyz2,
                                         const scalar_t* __restrict__ grad_dist1, const int* idx1,
                                         scalar_t* __restrict__ grad_xyz1,
                                         scalar_t* __restrict__ grad_xyz2)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y)
        {
            scalar_t x1 = xyz1[(i * n + j) * 3 + 0];
            scalar_t y1 = xyz1[(i * n + j) * 3 + 1];
            scalar_t z1 = xyz1[(i * n + j) * 3 + 2];
            int j2 = idx1[i * n + j];
            scalar_t x2 = xyz2[(i * m + j2) * 3 + 0];
            scalar_t y2 = xyz2[(i * m + j2) * 3 + 1];
            scalar_t z2 = xyz2[(i * m + j2) * 3 + 2];
            scalar_t g = grad_dist1[i * n + j] * 2;
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
        }
    }
}

std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2,
                                                            torch::Tensor idx1, torch::Tensor idx2,
                                                            torch::Tensor grad_dist1,
                                                            torch::Tensor grad_dist2)
{
    const int batch_size = xyz1.size(0);
    const int n = xyz1.size(1); // num_points point cloud A
    const int m = xyz2.size(1); // num_points point cloud B
    torch::Tensor grad_xyz1 = torch::zeros_like(xyz1);
    torch::Tensor grad_xyz2 = torch::zeros_like(xyz2);

    AT_DISPATCH_FLOATING_TYPES(
        xyz1.scalar_type(), "chamfer_dist_grad_cuda", ([&] {
            chamfer_dist_grad_kernel<scalar_t><<<dim3(1, 16, 1), 256>>>(
                batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
                grad_dist1.data_ptr<scalar_t>(), idx1.data_ptr<int>(),
                grad_xyz1.data_ptr<scalar_t>(), grad_xyz2.data_ptr<scalar_t>());

            chamfer_dist_grad_kernel<scalar_t><<<dim3(1, 16, 1), 256>>>(
                batch_size, m, xyz2.data_ptr<scalar_t>(), n, xyz1.data_ptr<scalar_t>(),
                grad_dist2.data_ptr<scalar_t>(), idx2.data_ptr<int>(),
                grad_xyz2.data_ptr<scalar_t>(), grad_xyz1.data_ptr<scalar_t>());
        }));

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("Error in chamfer_dist_grad_kernel_wrapper: %s\n", cudaGetErrorString(err));
    }
    return {grad_xyz1, grad_xyz2};
}
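
The backward kernel scatters the derivative of the squared distance, d/dp ||p - q||^2 = 2 (p - q), weighted by the incoming gradient, into both clouds with atomicAdd. A hedged pure-PyTorch mirror of chamfer_dist_grad_kernel_wrapper, useful for spot-checking the CUDA output on small inputs (not part of this commit):

import torch

def chamfer_dist_grad_reference(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2):
    # d/dp ||p - q||^2 = 2 (p - q); the same amount is subtracted at the matched point.
    grad_xyz1 = torch.zeros_like(xyz1)
    grad_xyz2 = torch.zeros_like(xyz2)

    # Direction 1: each point of xyz1 against its nearest neighbour in xyz2.
    idx1_3 = idx1.long().unsqueeze(-1).expand(-1, -1, 3)
    nn2 = torch.gather(xyz2, 1, idx1_3)
    g1 = (2.0 * grad_dist1).unsqueeze(-1) * (xyz1 - nn2)
    grad_xyz1 += g1
    grad_xyz2.scatter_add_(1, idx1_3, -g1)

    # Direction 2: each point of xyz2 against its nearest neighbour in xyz1.
    idx2_3 = idx2.long().unsqueeze(-1).expand(-1, -1, 3)
    nn1 = torch.gather(xyz1, 1, idx2_3)
    g2 = (2.0 * grad_dist2).unsqueeze(-1) * (xyz2 - nn1)
    grad_xyz2 += g2
    grad_xyz1.scatter_add_(1, idx2_3, -g2)

    return grad_xyz1, grad_xyz2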

cuda/src/chamfer_dist_gpu.cpp

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
#include "chamfer_dist.h"

std::vector<torch::Tensor> chamfer_dist(torch::Tensor xyz1, torch::Tensor xyz2)
{
    return chamfer_dist_kernel_wrapper(xyz1, xyz2);
}

std::vector<torch::Tensor> chamfer_dist_grad(torch::Tensor xyz1, torch::Tensor xyz2,
                                             torch::Tensor idx1, torch::Tensor idx2,
                                             torch::Tensor grad_dist1, torch::Tensor grad_dist2)
{
    return chamfer_dist_grad_kernel_wrapper(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
}

test/test_chamfer_dist.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import numpy as np
import os
import sys
import torch
import unittest

from torch.autograd import gradcheck

from . import run_if_cuda


ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.insert(0, ROOT)

from torch_points_kernels import ChamferFunction, chamfer_dist


class TestChamferDistance(unittest.TestCase):
    @run_if_cuda
    def test_chamfer_dist_grad(self):
        x = torch.rand(4, 64, 3).double()
        y = torch.rand(4, 128, 3).double()
        x.requires_grad = True
        y.requires_grad = True
        test = gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])

    @run_if_cuda
    def test_chamfer_dist(self):
        xyz1 = torch.from_numpy(np.array([[
            [0, 0, 0],
            [1, 1, 1],
            [2, 0, 1]
        ]])).float()
        xyz2 = torch.from_numpy(np.array([[[1, 0, 0], [1, 2, 1]]])).float()
        dist = chamfer_dist(xyz1.cuda(), xyz2.cuda())
        self.assertAlmostEqual(dist.item(), 2.333333, places=5)

    @run_if_cuda
    def test_chamfer_dist_ignore_zeros(self):
        xyz1 = torch.from_numpy(np.array([[
            [0, 0, 0],
            [1, 1, 1],
            [2, 0, 1]
        ]])).float()
        xyz2 = torch.from_numpy(np.array([[[1, 0, 0], [1, 2, 1]]])).float()
        dist = chamfer_dist(xyz1.cuda(), xyz2.cuda(), True)
        self.assertAlmostEqual(dist.item(), 3.0, places=5)


if __name__ == "__main__":
    unittest.main()
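
The tests import ChamferFunction and chamfer_dist from torch_points_kernels, but that Python wrapper is not shown in this excerpt. A minimal sketch of how such an autograd wrapper around the bindings above is typically written (the extension module name is an assumption, and the ignore_zeros masking exercised by test_chamfer_dist_ignore_zeros is omitted):

import torch
import tpcuda  # hypothetical name for the compiled extension exposing chamfer_dist / chamfer_dist_grad


class ChamferFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
            xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
        )
        return grad_xyz1, grad_xyz2


def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
    # Reduce both directional distances to a scalar; with this reduction the first
    # test case above evaluates to 4/3 + 1 = 2.3333. Masking of all-zero padding
    # points (ignore_zeros=True) is left out of this sketch.
    dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
    return torch.mean(dist1) + torch.mean(dist2)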

torch_points_kernels/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,4 +12,5 @@
     "knn",
     "region_grow",
     "instance_iou",
+    "chamfer_dist"
 ]
