Skip to content

Commit e9ee3bd

Browse files
committed
Make Chamfer Distance available for both float and double.
1 parent a50d3f2 commit e9ee3bd

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,5 @@ See [this useful chart](http://arnon.dk/matching-sm-architectures-arch-and-genco
6262
* [```Pointnet2_Tensorflow```](https://github.com/charlesq34/pointnet2) by [Charles R. Qi](https://github.com/charlesq34)
6363

6464
* [```Pointnet2_PyTorch```](https://github.com/erikwijmans/Pointnet2_PyTorch) by [Erik Wijmans](https://github.com/erikwijmans)
65+
66+
* [```GRNet```](https://github.com/hzxie/GRNet) by [Haozhe Xie](https://github.com/hzxie)

cuda/src/chamfer_dist.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
158158
const int batch_size = xyz1.size(0);
159159
const int n = xyz1.size(1); // num_points point cloud A
160160
const int m = xyz2.size(1); // num_points point cloud B
161-
torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));
162-
torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));
161+
torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type()));
162+
torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type()));
163163
torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
164164
torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
165165

@@ -219,8 +219,8 @@ std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1,
219219
const int batch_size = xyz1.size(0);
220220
const int n = xyz1.size(1); // num_points point cloud A
221221
const int m = xyz2.size(1); // num_points point cloud B
222-
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));
223-
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));
222+
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1);
223+
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2);
224224

225225
AT_DISPATCH_FLOATING_TYPES(
226226
xyz1.scalar_type(), "chamfer_dist_grad_cuda", ([&] {

0 commit comments

Comments
 (0)