@@ -158,8 +158,8 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
158158 const int batch_size = xyz1.size (0 );
159159 const int n = xyz1.size (1 ); // num_points point cloud A
160160 const int m = xyz2.size (1 ); // num_points point cloud B
161- torch::Tensor dist1 = torch::zeros ({batch_size, n}, torch::CUDA (torch:: kFloat ));
162- torch::Tensor dist2 = torch::zeros ({batch_size, m}, torch::CUDA (torch:: kFloat ));
161+ torch::Tensor dist1 = torch::zeros ({batch_size, n}, torch::CUDA (xyz1. scalar_type () ));
162+ torch::Tensor dist2 = torch::zeros ({batch_size, m}, torch::CUDA (xyz1. scalar_type () ));
163163 torch::Tensor idx1 = torch::zeros ({batch_size, n}, torch::CUDA (torch::kInt ));
164164 torch::Tensor idx2 = torch::zeros ({batch_size, m}, torch::CUDA (torch::kInt ));
165165
@@ -219,8 +219,8 @@ std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1,
219219 const int batch_size = xyz1.size (0 );
220220 const int n = xyz1.size (1 ); // num_points point cloud A
221221 const int m = xyz2.size (1 ); // num_points point cloud B
222- torch::Tensor grad_xyz1 = torch::zeros_like (xyz1, torch::CUDA (torch:: kFloat ) );
223- torch::Tensor grad_xyz2 = torch::zeros_like (xyz2, torch::CUDA (torch:: kFloat ) );
222+ torch::Tensor grad_xyz1 = torch::zeros_like (xyz1);
223+ torch::Tensor grad_xyz2 = torch::zeros_like (xyz2);
224224
225225 AT_DISPATCH_FLOATING_TYPES (
226226 xyz1.scalar_type (), " chamfer_dist_grad_cuda" , ([&] {
0 commit comments