|
| 1 | +#include <cuda.h> |
| 2 | +#include <cuda_runtime.h> |
| 3 | +#include <torch/extension.h> |
| 4 | + |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +__global__ void chamfer_dist_kernel(int batch_size, int n, const float* xyz1, int m, |
| 8 | + const float* xyz2, float* dist, int* indexes) |
| 9 | +{ |
| 10 | + const int batch = 512; |
| 11 | + __shared__ float buf[batch * 3]; |
| 12 | + for (int i = blockIdx.x; i < batch_size; i += gridDim.x) |
| 13 | + { |
| 14 | + for (int k2 = 0; k2 < m; k2 += batch) |
| 15 | + { |
| 16 | + int end_k = min(m, k2 + batch) - k2; |
| 17 | + for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) |
| 18 | + { |
| 19 | + buf[j] = xyz2[(i * m + k2) * 3 + j]; |
| 20 | + } |
| 21 | + __syncthreads(); |
| 22 | + for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) |
| 23 | + { |
| 24 | + float x1 = xyz1[(i * n + j) * 3 + 0]; |
| 25 | + float y1 = xyz1[(i * n + j) * 3 + 1]; |
| 26 | + float z1 = xyz1[(i * n + j) * 3 + 2]; |
| 27 | + float best_dist = 0; |
| 28 | + int best_dist_index = 0; |
| 29 | + int end_ka = end_k - (end_k & 3); |
| 30 | + if (end_ka == batch) |
| 31 | + { |
| 32 | + for (int k = 0; k < batch; k += 4) |
| 33 | + { |
| 34 | + { |
| 35 | + float x2 = buf[k * 3 + 0] - x1; |
| 36 | + float y2 = buf[k * 3 + 1] - y1; |
| 37 | + float z2 = buf[k * 3 + 2] - z1; |
| 38 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 39 | + |
| 40 | + if (k == 0 || dist < best_dist) |
| 41 | + { |
| 42 | + best_dist = dist; |
| 43 | + best_dist_index = k + k2; |
| 44 | + } |
| 45 | + } |
| 46 | + { |
| 47 | + float x2 = buf[k * 3 + 3] - x1; |
| 48 | + float y2 = buf[k * 3 + 4] - y1; |
| 49 | + float z2 = buf[k * 3 + 5] - z1; |
| 50 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 51 | + if (dist < best_dist) |
| 52 | + { |
| 53 | + best_dist = dist; |
| 54 | + best_dist_index = k + k2 + 1; |
| 55 | + } |
| 56 | + } |
| 57 | + { |
| 58 | + float x2 = buf[k * 3 + 6] - x1; |
| 59 | + float y2 = buf[k * 3 + 7] - y1; |
| 60 | + float z2 = buf[k * 3 + 8] - z1; |
| 61 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 62 | + if (dist < best_dist) |
| 63 | + { |
| 64 | + best_dist = dist; |
| 65 | + best_dist_index = k + k2 + 2; |
| 66 | + } |
| 67 | + } |
| 68 | + { |
| 69 | + float x2 = buf[k * 3 + 9] - x1; |
| 70 | + float y2 = buf[k * 3 + 10] - y1; |
| 71 | + float z2 = buf[k * 3 + 11] - z1; |
| 72 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 73 | + if (dist < best_dist) |
| 74 | + { |
| 75 | + best_dist = dist; |
| 76 | + best_dist_index = k + k2 + 3; |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + } |
| 81 | + else |
| 82 | + { |
| 83 | + for (int k = 0; k < end_ka; k += 4) |
| 84 | + { |
| 85 | + { |
| 86 | + float x2 = buf[k * 3 + 0] - x1; |
| 87 | + float y2 = buf[k * 3 + 1] - y1; |
| 88 | + float z2 = buf[k * 3 + 2] - z1; |
| 89 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 90 | + if (k == 0 || dist < best_dist) |
| 91 | + { |
| 92 | + best_dist = dist; |
| 93 | + best_dist_index = k + k2; |
| 94 | + } |
| 95 | + } |
| 96 | + { |
| 97 | + float x2 = buf[k * 3 + 3] - x1; |
| 98 | + float y2 = buf[k * 3 + 4] - y1; |
| 99 | + float z2 = buf[k * 3 + 5] - z1; |
| 100 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 101 | + if (dist < best_dist) |
| 102 | + { |
| 103 | + best_dist = dist; |
| 104 | + best_dist_index = k + k2 + 1; |
| 105 | + } |
| 106 | + } |
| 107 | + { |
| 108 | + float x2 = buf[k * 3 + 6] - x1; |
| 109 | + float y2 = buf[k * 3 + 7] - y1; |
| 110 | + float z2 = buf[k * 3 + 8] - z1; |
| 111 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 112 | + if (dist < best_dist) |
| 113 | + { |
| 114 | + best_dist = dist; |
| 115 | + best_dist_index = k + k2 + 2; |
| 116 | + } |
| 117 | + } |
| 118 | + { |
| 119 | + float x2 = buf[k * 3 + 9] - x1; |
| 120 | + float y2 = buf[k * 3 + 10] - y1; |
| 121 | + float z2 = buf[k * 3 + 11] - z1; |
| 122 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 123 | + if (dist < best_dist) |
| 124 | + { |
| 125 | + best_dist = dist; |
| 126 | + best_dist_index = k + k2 + 3; |
| 127 | + } |
| 128 | + } |
| 129 | + } |
| 130 | + } |
| 131 | + for (int k = end_ka; k < end_k; k++) |
| 132 | + { |
| 133 | + float x2 = buf[k * 3 + 0] - x1; |
| 134 | + float y2 = buf[k * 3 + 1] - y1; |
| 135 | + float z2 = buf[k * 3 + 2] - z1; |
| 136 | + float dist = x2 * x2 + y2 * y2 + z2 * z2; |
| 137 | + if (k == 0 || dist < best_dist) |
| 138 | + { |
| 139 | + best_dist = dist; |
| 140 | + best_dist_index = k + k2; |
| 141 | + } |
| 142 | + } |
| 143 | + if (k2 == 0 || dist[(i * n + j)] > best_dist) |
| 144 | + { |
| 145 | + dist[(i * n + j)] = best_dist; |
| 146 | + indexes[(i * n + j)] = best_dist_index; |
| 147 | + } |
| 148 | + } |
| 149 | + __syncthreads(); |
| 150 | + } |
| 151 | + } |
| 152 | +} |
| 153 | + |
| 154 | +std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2) |
| 155 | +{ |
| 156 | + const int batch_size = xyz1.size(0); |
| 157 | + const int n = xyz1.size(1); // num_points point cloud A |
| 158 | + const int m = xyz2.size(1); // num_points point cloud B |
| 159 | + torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat)); |
| 160 | + torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat)); |
| 161 | + torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt)); |
| 162 | + torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt)); |
| 163 | + |
| 164 | + chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, n, xyz1.data_ptr<float>(), m, |
| 165 | + xyz2.data_ptr<float>(), dist1.data_ptr<float>(), |
| 166 | + idx1.data_ptr<int>()); |
| 167 | + chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(batch_size, m, xyz2.data_ptr<float>(), n, |
| 168 | + xyz1.data_ptr<float>(), dist2.data_ptr<float>(), |
| 169 | + idx2.data_ptr<int>()); |
| 170 | + |
| 171 | + cudaError_t err = cudaGetLastError(); |
| 172 | + if (err != cudaSuccess) |
| 173 | + { |
| 174 | + printf("Error in chamfer_dist_kernel_wrapper: %s\n", cudaGetErrorString(err)); |
| 175 | + } |
| 176 | + return {dist1, dist2, idx1, idx2}; |
| 177 | +} |
| 178 | + |
| 179 | +__global__ void chamfer_dist_grad_kernel(int b, int n, const float* xyz1, int m, const float* xyz2, |
| 180 | + const float* grad_dist1, const int* idx1, float* grad_xyz1, |
| 181 | + float* grad_xyz2) |
| 182 | +{ |
| 183 | + for (int i = blockIdx.x; i < b; i += gridDim.x) |
| 184 | + { |
| 185 | + for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) |
| 186 | + { |
| 187 | + float x1 = xyz1[(i * n + j) * 3 + 0]; |
| 188 | + float y1 = xyz1[(i * n + j) * 3 + 1]; |
| 189 | + float z1 = xyz1[(i * n + j) * 3 + 2]; |
| 190 | + int j2 = idx1[i * n + j]; |
| 191 | + float x2 = xyz2[(i * m + j2) * 3 + 0]; |
| 192 | + float y2 = xyz2[(i * m + j2) * 3 + 1]; |
| 193 | + float z2 = xyz2[(i * m + j2) * 3 + 2]; |
| 194 | + float g = grad_dist1[i * n + j] * 2; |
| 195 | + atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2)); |
| 196 | + atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2)); |
| 197 | + atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2)); |
| 198 | + atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2))); |
| 199 | + atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2))); |
| 200 | + atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2))); |
| 201 | + } |
| 202 | + } |
| 203 | +} |
| 204 | + |
| 205 | +std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, torch::Tensor xyz2, |
| 206 | + torch::Tensor idx1, torch::Tensor idx2, |
| 207 | + torch::Tensor grad_dist1, |
| 208 | + torch::Tensor grad_dist2) |
| 209 | +{ |
| 210 | + const int batch_size = xyz1.size(0); |
| 211 | + const int n = xyz1.size(1); // num_points point cloud A |
| 212 | + const int m = xyz2.size(1); // num_points point cloud B |
| 213 | + torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat)); |
| 214 | + torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat)); |
| 215 | + |
| 216 | + chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>( |
| 217 | + batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(), |
| 218 | + grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(), grad_xyz1.data_ptr<float>(), |
| 219 | + grad_xyz2.data_ptr<float>()); |
| 220 | + chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>( |
| 221 | + batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(), |
| 222 | + grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(), grad_xyz2.data_ptr<float>(), |
| 223 | + grad_xyz1.data_ptr<float>()); |
| 224 | + |
| 225 | + cudaError_t err = cudaGetLastError(); |
| 226 | + if (err != cudaSuccess) |
| 227 | + { |
| 228 | + printf("Error in chamfer_dist_grad_kernel_wrapper: %s\n", cudaGetErrorString(err)); |
| 229 | + } |
| 230 | + return {grad_xyz1, grad_xyz2}; |
| 231 | +} |
0 commit comments