Skip to content

Commit 652c179

Browse files
committed
Create the implementation for Chamfer Distance. [ci skip]
1 parent 7994646 commit 652c179

File tree

6 files changed

+306
-0
lines changed

6 files changed

+306
-0
lines changed

cuda/include/chamfer_dist.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#include <torch/extension.h>
2+
#include <vector>
3+
4+
std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2);
5+
6+
std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2,
7+
torch::Tensor idx1, torch::Tensor idx2,
8+
torch::Tensor grad_dist1,
9+
torch::Tensor grad_dist2);

cuda/src/chamfer_dist.cu

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
#include <cuda.h>
2+
#include <cuda_runtime.h>
3+
#include <torch/extension.h>
4+
5+
#include <vector>
6+
7+
/**
 * Folds one candidate point into a running nearest-neighbour result.
 *
 * @param candidate            Pointer to the candidate's three consecutive floats (x, y, z).
 * @param candidate_index      Index to record if this candidate becomes the best one.
 * @param take_unconditionally When true the candidate replaces the current best regardless of its
 *                             distance; used for the first candidate of a tile to seed the search.
 * @param x1, y1, z1           Coordinates of the query point.
 * @param best_dist            In/out: smallest squared distance seen so far.
 * @param best_dist_index      In/out: index of the candidate that produced best_dist.
 */
__device__ __forceinline__ void chamfer_update_nearest(const float* candidate, int candidate_index,
                                                       bool take_unconditionally, float x1,
                                                       float y1, float z1, float& best_dist,
                                                       int& best_dist_index)
{
    const float dx = candidate[0] - x1;
    const float dy = candidate[1] - y1;
    const float dz = candidate[2] - z1;
    const float sq_dist = dx * dx + dy * dy + dz * dz;
    if (take_unconditionally || sq_dist < best_dist)
    {
        best_dist = sq_dist;
        best_dist_index = candidate_index;
    }
}

/**
 * For every point of xyz1, finds the closest point of xyz2 within the same batch element and
 * writes the squared distance and the index of that closest point.
 *
 * Layout: blockIdx.x strides over batch elements; threadIdx.x + blockIdx.y * blockDim.x strides
 * over the n query points, so any grid/block size is valid. xyz2 is processed in tiles of 512
 * points staged through 6 KB of static shared memory.
 *
 * @param batch_size Number of batch elements.
 * @param n          Points per batch element in xyz1.
 * @param xyz1       Query points, read as a packed [batch_size, n, 3] float array.
 * @param m          Points per batch element in xyz2.
 * @param xyz2       Candidate points, read as a packed [batch_size, m, 3] float array.
 * @param dist       Output [batch_size, n]: squared distance to the nearest candidate.
 * @param indexes    Output [batch_size, n]: index (0..m-1) of the nearest candidate.
 *
 * When m == 0 nothing is written, so the outputs keep whatever the caller stored in them.
 */
__global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, int m,
                                    const float* xyz2, float* dist, int* indexes)
{
    const int batch = 512;
    __shared__ float buf[batch * 3];
    for (int i = blockIdx.x; i < batch_size; i += gridDim.x)
    {
        for (int k2 = 0; k2 < m; k2 += batch)
        {
            // Number of candidates in this tile (the last tile may be partial).
            const int end_k = min(m, k2 + batch) - k2;

            // 64-bit offsets so that batch_size * m * 3 cannot overflow a 32-bit int.
            const size_t tile_base = (static_cast<size_t>(i) * m + k2) * 3;
            for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x)
            {
                buf[j] = xyz2[tile_base + j];
            }
            // The tile must be fully staged before any thread scans it.
            __syncthreads();

            for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y)
            {
                const size_t point = static_cast<size_t>(i) * n + j;
                const float x1 = xyz1[point * 3 + 0];
                const float y1 = xyz1[point * 3 + 1];
                const float z1 = xyz1[point * 3 + 2];

                float best_dist = 0.0f;
                int best_dist_index = 0;

                if (end_k == batch)
                {
                    // Full tile: compile-time trip count lets the compiler unroll freely.
#pragma unroll 4
                    for (int k = 0; k < batch; ++k)
                    {
                        chamfer_update_nearest(buf + k * 3, k2 + k, k == 0, x1, y1, z1, best_dist,
                                               best_dist_index);
                    }
                }
                else
                {
#pragma unroll 4
                    for (int k = 0; k < end_k; ++k)
                    {
                        chamfer_update_nearest(buf + k * 3, k2 + k, k == 0, x1, y1, z1, best_dist,
                                               best_dist_index);
                    }
                }

                // The first tile seeds the output; later tiles only improve on it.
                if (k2 == 0 || dist[point] > best_dist)
                {
                    dist[point] = best_dist;
                    indexes[point] = best_dist_index;
                }
            }
            // Nobody may overwrite the tile while other threads are still reading it.
            __syncthreads();
        }
    }
}
153+
154+
/**
 * Host-side launcher for the forward Chamfer distance.
 *
 * @param xyz1 Point cloud A, float32 CUDA tensor of shape [batch, n, 3].
 * @param xyz2 Point cloud B, float32 CUDA tensor of shape [batch, m, 3] on the same device.
 * @return {dist1, dist2, idx1, idx2} where dist1/idx1 have shape [batch, n] and hold, for each
 *         point of A, the squared distance to and index of its nearest point in B; dist2/idx2
 *         have shape [batch, m] and hold the same for B against A. idx tensors are int32.
 *         If either cloud is empty all four tensors are returned zero-filled.
 *
 * Raises a c10::Error (surfaced as a Python exception) on invalid inputs or a failed launch.
 *
 * NOTE(review): the kernels are launched on the default stream of the current device, not on
 * PyTorch's current stream / the tensors' device — confirm this is acceptable for callers that
 * use non-default streams or multiple GPUs.
 */
std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2)
{
    TORCH_CHECK(xyz1.is_cuda() && xyz2.is_cuda(),
                "chamfer_dist_kernel_wrapper: xyz1 and xyz2 must be CUDA tensors");
    TORCH_CHECK(xyz1.device() == xyz2.device(),
                "chamfer_dist_kernel_wrapper: xyz1 and xyz2 must be on the same device");
    TORCH_CHECK(xyz1.scalar_type() == torch::kFloat && xyz2.scalar_type() == torch::kFloat,
                "chamfer_dist_kernel_wrapper: xyz1 and xyz2 must be float32 tensors");
    TORCH_CHECK(xyz1.dim() == 3 && xyz1.size(2) == 3 && xyz2.dim() == 3 && xyz2.size(2) == 3,
                "chamfer_dist_kernel_wrapper: xyz1 and xyz2 must have shape [batch, num_points, 3]");
    TORCH_CHECK(xyz1.size(0) == xyz2.size(0),
                "chamfer_dist_kernel_wrapper: xyz1 and xyz2 must have the same batch size");

    // The kernel indexes the raw buffers as packed arrays, so strided views must be compacted.
    xyz1 = xyz1.contiguous();
    xyz2 = xyz2.contiguous();

    const int batch_size = static_cast<int>(xyz1.size(0));
    const int n = static_cast<int>(xyz1.size(1)); // num_points point cloud A
    const int m = static_cast<int>(xyz2.size(1)); // num_points point cloud B

    // Allocate the outputs next to the inputs rather than on whatever device is current.
    const auto float_options = xyz1.options().dtype(torch::kFloat);
    const auto index_options = xyz1.options().dtype(torch::kInt);
    torch::Tensor dist1 = torch::zeros({batch_size, n}, float_options);
    torch::Tensor dist2 = torch::zeros({batch_size, m}, float_options);
    torch::Tensor idx1 = torch::zeros({batch_size, n}, index_options);
    torch::Tensor idx2 = torch::zeros({batch_size, m}, index_options);

    // With an empty cloud the kernels would not write anything; skip the launches entirely.
    if (batch_size == 0 || n == 0 || m == 0)
    {
        return {dist1, dist2, idx1, idx2};
    }

    // A -> B nearest neighbours.
    chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, n, xyz1.data_ptr<float>(), m,
                                                  xyz2.data_ptr<float>(), dist1.data_ptr<float>(),
                                                  idx1.data_ptr<int>());
    // B -> A nearest neighbours (same kernel with the roles of the clouds swapped).
    chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, m, xyz2.data_ptr<float>(), n,
                                                  xyz1.data_ptr<float>(), dist2.data_ptr<float>(),
                                                  idx2.data_ptr<int>());

    // Fail loudly instead of printing and returning zero-filled results.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "Error in chamfer_dist_kernel_wrapper: ", cudaGetErrorString(err));
    return {dist1, dist2, idx1, idx2};
}
178+
179+
/**
 * Accumulates the gradient of the one-directional nearest-neighbour squared distance.
 *
 * For each point p of xyz1 and its matched point q = idx1[p] of xyz2, adds
 * 2 * grad_dist1[p] * (p - q) to grad_xyz1[p] and the negated value to grad_xyz2[q].
 * Both outputs are accumulated with atomicAdd, so they must be initialised by the caller.
 *
 * Layout: blockIdx.x strides over the b batch elements; threadIdx.x + blockIdx.y * blockDim.x
 * strides over the n points of xyz1. All buffers are read as packed arrays:
 * xyz1/grad_xyz1 as [b, n, 3], xyz2/grad_xyz2 as [b, m, 3], grad_dist1/idx1 as [b, n].
 */
__global__ void chamfer_dist_grad_kernel(int b, int n, const float* xyz1, int m, const float* xyz2,
                                         const float* grad_dist1, const int* idx1, float* grad_xyz1,
                                         float* grad_xyz2)
{
    const int first_point = threadIdx.x + blockIdx.y * blockDim.x;
    const int point_stride = blockDim.x * gridDim.y;

    for (int cloud = blockIdx.x; cloud < b; cloud += gridDim.x)
    {
        for (int p = first_point; p < n; p += point_stride)
        {
            // Flat positions of the source point and of its matched neighbour.
            const int src = cloud * n + p;
            const int dst = cloud * m + idx1[src];

            const float* from = xyz1 + src * 3;
            const float* to = xyz2 + dst * 3;

            // d(|p - q|^2)/dp = 2 (p - q), scaled by the upstream gradient.
            const float scale = grad_dist1[src] * 2;
            const float gx = scale * (from[0] - to[0]);
            const float gy = scale * (from[1] - to[1]);
            const float gz = scale * (from[2] - to[2]);

            float* out_src = grad_xyz1 + src * 3;
            atomicAdd(out_src + 0, gx);
            atomicAdd(out_src + 1, gy);
            atomicAdd(out_src + 2, gz);

            // Several source points may share one neighbour, hence the atomics.
            float* out_dst = grad_xyz2 + dst * 3;
            atomicAdd(out_dst + 0, -gx);
            atomicAdd(out_dst + 1, -gy);
            atomicAdd(out_dst + 2, -gz);
        }
    }
}
204+
205+
/**
 * Host-side launcher for the backward pass of the Chamfer distance.
 *
 * @param xyz1       Point cloud A, float32 CUDA tensor of shape [batch, n, 3].
 * @param xyz2       Point cloud B, float32 CUDA tensor of shape [batch, m, 3].
 * @param idx1       int32 CUDA tensor with batch * n entries: nearest point of B for each point of A.
 * @param idx2       int32 CUDA tensor with batch * m entries: nearest point of A for each point of B.
 * @param grad_dist1 float32 CUDA tensor with batch * n entries: upstream gradient of dist1.
 * @param grad_dist2 float32 CUDA tensor with batch * m entries: upstream gradient of dist2.
 * @return {grad_xyz1, grad_xyz2}, shaped like xyz1 and xyz2. Both are zero when either cloud
 *         is empty.
 *
 * Raises a c10::Error (surfaced as a Python exception) on invalid inputs or a failed launch.
 *
 * NOTE(review): index values are not range-checked here; they are assumed to come from
 * chamfer_dist_kernel_wrapper — confirm no other caller supplies them.
 */
std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2,
                                                            torch::Tensor idx1, torch::Tensor idx2,
                                                            torch::Tensor grad_dist1,
                                                            torch::Tensor grad_dist2)
{
    TORCH_CHECK(xyz1.is_cuda() && xyz2.is_cuda() && idx1.is_cuda() && idx2.is_cuda() &&
                    grad_dist1.is_cuda() && grad_dist2.is_cuda(),
                "chamfer_dist_grad_kernel_wrapper: all inputs must be CUDA tensors");
    TORCH_CHECK(xyz1.scalar_type() == torch::kFloat && xyz2.scalar_type() == torch::kFloat &&
                    grad_dist1.scalar_type() == torch::kFloat &&
                    grad_dist2.scalar_type() == torch::kFloat,
                "chamfer_dist_grad_kernel_wrapper: point and gradient tensors must be float32");
    TORCH_CHECK(idx1.scalar_type() == torch::kInt && idx2.scalar_type() == torch::kInt,
                "chamfer_dist_grad_kernel_wrapper: idx1 and idx2 must be int32 tensors");
    TORCH_CHECK(xyz1.dim() == 3 && xyz1.size(2) == 3 && xyz2.dim() == 3 && xyz2.size(2) == 3,
                "chamfer_dist_grad_kernel_wrapper: xyz1 and xyz2 must have shape "
                "[batch, num_points, 3]");
    TORCH_CHECK(xyz1.size(0) == xyz2.size(0),
                "chamfer_dist_grad_kernel_wrapper: xyz1 and xyz2 must have the same batch size");
    TORCH_CHECK(idx1.numel() == xyz1.size(0) * xyz1.size(1) &&
                    grad_dist1.numel() == xyz1.size(0) * xyz1.size(1),
                "chamfer_dist_grad_kernel_wrapper: idx1 and grad_dist1 must have one entry per "
                "point of xyz1");
    TORCH_CHECK(idx2.numel() == xyz2.size(0) * xyz2.size(1) &&
                    grad_dist2.numel() == xyz2.size(0) * xyz2.size(1),
                "chamfer_dist_grad_kernel_wrapper: idx2 and grad_dist2 must have one entry per "
                "point of xyz2");

    // The kernel reads raw packed buffers. Upstream gradients in particular are frequently
    // expanded (stride-0) views, which would otherwise be read out of bounds.
    xyz1 = xyz1.contiguous();
    xyz2 = xyz2.contiguous();
    idx1 = idx1.contiguous();
    idx2 = idx2.contiguous();
    grad_dist1 = grad_dist1.contiguous();
    grad_dist2 = grad_dist2.contiguous();

    const int batch_size = static_cast<int>(xyz1.size(0));
    const int n = static_cast<int>(xyz1.size(1)); // num_points point cloud A
    const int m = static_cast<int>(xyz2.size(1)); // num_points point cloud B

    // The kernel accumulates with atomicAdd, so both outputs must start at zero.
    torch::Tensor grad_xyz1 = torch::zeros(xyz1.sizes(), xyz1.options());
    torch::Tensor grad_xyz2 = torch::zeros(xyz2.sizes(), xyz2.options());

    // An empty cloud has no valid neighbour to dereference; the gradient is simply zero.
    if (batch_size == 0 || n == 0 || m == 0)
    {
        return {grad_xyz1, grad_xyz2};
    }

    // Contribution of dist1 (A -> B matches).
    chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
        batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
        grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(), grad_xyz1.data_ptr<float>(),
        grad_xyz2.data_ptr<float>());
    // Contribution of dist2 (B -> A matches); clouds and output buffers swap roles.
    chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
        batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
        grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(), grad_xyz2.data_ptr<float>(),
        grad_xyz1.data_ptr<float>());

    // Fail loudly instead of printing and returning partial gradients.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "Error in chamfer_dist_grad_kernel_wrapper: ", cudaGetErrorString(err));
    return {grad_xyz1, grad_xyz2};
}

cuda/src/chamfer_dist_gpu.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include "chamfer_dist.h"
2+
3+
// Forward entry point of the Chamfer distance: passes both point clouds to the CUDA launcher
// declared in chamfer_dist.h and returns its result unchanged.
std::vector<torch::Tensor> chamfer_dist(torch::Tensor xyz1, torch::Tensor xyz2)
{
    std::vector<torch::Tensor> outputs = chamfer_dist_kernel_wrapper(xyz1, xyz2);
    return outputs;
}
7+
8+
// Backward entry point of the Chamfer distance: passes the point clouds, the nearest-neighbour
// indices and the upstream gradients to the CUDA launcher declared in chamfer_dist.h and returns
// its result unchanged.
std::vector<torch::Tensor> chamfer_dist_grad(torch::Tensor xyz1, torch::Tensor xyz2,
                                             torch::Tensor idx1, torch::Tensor idx2,
                                             torch::Tensor grad_dist1, torch::Tensor grad_dist2)
{
    std::vector<torch::Tensor> gradients =
        chamfer_dist_grad_kernel_wrapper(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
    return gradients;
}

test/test_chamfer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
import sys
3+
import torch
4+
import unittest
5+
6+
from torch.autograd import gradcheck
7+
8+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
9+
sys.path.insert(0, ROOT)
10+
11+
from torch_points_kernels import ChamferFunction
12+
13+
14+
class TestChamferDistance(unittest.TestCase):
    """Checks ChamferFunction against a brute-force PyTorch reference."""

    @unittest.skipUnless(torch.cuda.is_available(), "ChamferFunction requires a CUDA device")
    def test_chamfer_dist(self):
        # The CUDA kernels only accept float32, which is too imprecise for a finite-difference
        # gradcheck, so forward and backward are compared with an autograd reference instead.
        torch.manual_seed(0)
        x = torch.rand(4, 64, 3, device="cuda", requires_grad=True)
        y = torch.rand(4, 128, 3, device="cuda", requires_grad=True)
        x_ref = x.detach().clone().requires_grad_(True)
        y_ref = y.detach().clone().requires_grad_(True)

        dist1, dist2 = ChamferFunction.apply(x, y)

        # Reference: all pairwise squared distances, then the minimum in each direction.
        pairwise = (x_ref.unsqueeze(2) - y_ref.unsqueeze(1)).pow(2).sum(dim=-1)
        ref1 = pairwise.min(dim=2)[0]
        ref2 = pairwise.min(dim=1)[0]

        self.assertEqual(dist1.shape, ref1.shape)
        self.assertEqual(dist2.shape, ref2.shape)
        self.assertTrue(torch.allclose(dist1, ref1, rtol=1e-4, atol=1e-5))
        self.assertTrue(torch.allclose(dist2, ref2, rtol=1e-4, atol=1e-5))

        (dist1.mean() + dist2.mean()).backward()
        (ref1.mean() + ref2.mean()).backward()

        self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=1e-4, atol=1e-5))
        self.assertTrue(torch.allclose(y.grad, y_ref.grad, rtol=1e-4, atol=1e-5))
21+
22+
23+
if __name__ == '__main__':
24+
unittest.main()

torch_points_kernels/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
"knn",
1313
"region_grow",
1414
"instance_iou",
15+
"chamfer_dist"
1516
]

torch_points_kernels/torchpoints.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,31 @@ def ball_query(
235235
return ball_query_dense(radius, nsample, x, y, sort=sort)
236236
else:
237237
raise Exception("unrecognized mode {}".format(mode))
238+
239+
240+
class ChamferFunction(Function):
    """Autograd wrapper around the ``tpcuda`` Chamfer distance extension.

    ``forward`` returns the two per-point distance tensors produced by
    ``tpcuda.chamfer_dist``; the nearest-neighbour indices it also produces are
    kept only for ``backward``.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        # The extension reads raw buffers, so hand it densely packed tensors.
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)

        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        # Incoming gradients are often expanded (stride-0) views, e.g. after ``.sum()``;
        # materialise them before their storage is indexed directly.
        grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
            xyz1, xyz2, idx1, idx2, grad_dist1.contiguous(), grad_dist2.contiguous()
        )
        return grad_xyz1, grad_xyz2
254+
255+
256+
def chamfer_dist(self, xyz1, xyz2, ignore_zeros=False):
    """Symmetric Chamfer loss between two batches of point clouds.

    Returns ``mean(dist1) + mean(dist2)``, where ``dist1`` and ``dist2`` are the two
    per-point distance tensors returned by ``ChamferFunction``.

    Arguments:
        self: unused.
            NOTE(review): this looks like a leftover from a module ``forward``;
            it is kept so existing positional callers keep working — confirm and drop.
        xyz1: first point cloud batch, indexed as ``[batch, num_points, coords]``.
        xyz2: second point cloud batch, same layout.
        ignore_zeros: when True *and* the batch size is 1, points whose coordinates
            are all exactly zero are removed from both clouds before the distance is
            computed. With a larger batch the flag has no effect.
    """
    batch_size = xyz1.size(0)
    if batch_size == 1 and ignore_zeros:
        # A point is padding only if every coordinate is zero. Testing the coordinate
        # sum instead would also discard real points such as (1, -1, 0).
        non_zeros1 = xyz1.ne(0).any(dim=2)
        non_zeros2 = xyz2.ne(0).any(dim=2)
        xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
        xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)

    dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
    return torch.mean(dist1) + torch.mean(dist2)

0 commit comments

Comments
 (0)