|
| 1 | +#include <cmath> |
| 2 | +#include <cstdio> |
| 3 | +#include <cstdlib> |
| 4 | +#include <torch/extension.h> |
| 5 | + |
| 6 | +#include "cuda_utils.h" |
| 7 | + |
| 8 | +#define CUDA_NUM_THREADS 512 |
| 9 | + |
| 10 | +// Computer the number of threads needed in GPU |
| 11 | +inline int get_n_threads(int n) |
| 12 | +{ |
| 13 | + const int pow_2 = std::log(static_cast<float>(n)) / std::log(2.0); |
| 14 | + return max(min(1 << pow_2, CUDA_NUM_THREADS), 1); |
| 15 | +} |
| 16 | + |
| 17 | +__device__ int compute_index(int offset_x, int offset_y, int offset_z, int scale) |
| 18 | +{ |
| 19 | + return offset_x * scale * scale + offset_y * scale + offset_z; |
| 20 | +} |
| 21 | + |
| 22 | +template <typename scalar_t> |
| 23 | +__global__ void cubic_feature_sampling_kernel(int scale, int neighborhood_size, int n_vertices, |
| 24 | + int n_pts, int n_cubic_channels, |
| 25 | + const scalar_t* __restrict__ ptcloud, |
| 26 | + const scalar_t* __restrict__ cubic_features, |
| 27 | + scalar_t* __restrict__ point_features, |
| 28 | + int* __restrict__ grid_pt_indexes) |
| 29 | +{ |
| 30 | + int batch_index = blockIdx.x; |
| 31 | + int index = threadIdx.x; |
| 32 | + int stride = blockDim.x; |
| 33 | + int cub_scale = scale * scale * scale; |
| 34 | + |
| 35 | + ptcloud += batch_index * n_pts * 3; |
| 36 | + cubic_features += batch_index * n_cubic_channels * cub_scale; |
| 37 | + point_features += batch_index * n_pts * n_vertices * n_cubic_channels; |
| 38 | + grid_pt_indexes += batch_index * n_pts * n_vertices; |
| 39 | + |
| 40 | + for (int i = index; i < n_pts; i += stride) |
| 41 | + { |
| 42 | + scalar_t pt_x = ptcloud[i * 3 + 0]; |
| 43 | + scalar_t pt_y = ptcloud[i * 3 + 1]; |
| 44 | + scalar_t pt_z = ptcloud[i * 3 + 2]; |
| 45 | + |
| 46 | + int lower_x = std::floor(pt_x); |
| 47 | + int upper_x = std::ceil(pt_x); |
| 48 | + if (lower_x == upper_x) |
| 49 | + { |
| 50 | + upper_x += 1; |
| 51 | + } |
| 52 | + int lower_y = std::floor(pt_y); |
| 53 | + int upper_y = std::ceil(pt_y); |
| 54 | + if (lower_y == upper_y) |
| 55 | + { |
| 56 | + upper_y += 1; |
| 57 | + } |
| 58 | + int lower_z = std::floor(pt_z); |
| 59 | + int upper_z = std::ceil(pt_z); |
| 60 | + if (lower_z == upper_z) |
| 61 | + { |
| 62 | + upper_z += 1; |
| 63 | + } |
| 64 | + |
| 65 | + int ns = neighborhood_size - 1; |
| 66 | + int vertex_idx = 0; |
| 67 | + for (int j = lower_x - ns; j <= upper_x + ns; ++j) |
| 68 | + { |
| 69 | + for (int k = lower_y - ns; k <= upper_y + ns; ++k) |
| 70 | + { |
| 71 | + for (int m = lower_z - ns; m <= upper_z + ns; ++m) |
| 72 | + { |
| 73 | + if (j < 0 || j >= scale || k < 0 || k >= scale || m < 0 || m >= scale) |
| 74 | + { |
| 75 | + // Ignore points lies out of the grid |
| 76 | + grid_pt_indexes[i * n_vertices + vertex_idx++] = -1; |
| 77 | + } |
| 78 | + else |
| 79 | + { |
| 80 | + // Calcuating indexes for adjacent vertices |
| 81 | + grid_pt_indexes[i * n_vertices + vertex_idx++] = |
| 82 | + compute_index(j, k, m, scale); |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + // Gather Features |
| 89 | + for (int j = 0; j < n_vertices; ++j) |
| 90 | + { |
| 91 | + for (int k = 0; k < n_cubic_channels; ++k) |
| 92 | + { |
| 93 | + int vertex_idx = grid_pt_indexes[i * n_vertices + j]; |
| 94 | + if (vertex_idx == -1) |
| 95 | + { |
| 96 | + continue; |
| 97 | + } |
| 98 | + int feature_idx = i * n_vertices * n_cubic_channels + j * n_cubic_channels + k; |
| 99 | + scalar_t feature_val = cubic_features[k * cub_scale + vertex_idx]; |
| 100 | + point_features[feature_idx] = feature_val; |
| 101 | + } |
| 102 | + } |
| 103 | + } |
| 104 | +} |
| 105 | + |
| 106 | +std::vector<torch::Tensor> cubic_feature_sampling_kernel_wrapper(int scale, int neighborhood_size, |
| 107 | + torch::Tensor ptcloud, |
| 108 | + torch::Tensor cubic_features, |
| 109 | + cudaStream_t stream) |
| 110 | +{ |
| 111 | + int batch_size = ptcloud.size(0); |
| 112 | + int n_pts = ptcloud.size(1); |
| 113 | + int n_cubic_channels = cubic_features.size(1); |
| 114 | + |
| 115 | + int n_vertices = std::pow(neighborhood_size * 2, 3); |
| 116 | + torch::Tensor point_features = torch::zeros({batch_size, n_pts, n_vertices, n_cubic_channels}, |
| 117 | + torch::CUDA(ptcloud.scalar_type())); |
| 118 | + torch::Tensor grid_pt_indexes = |
| 119 | + torch::zeros({batch_size, n_pts, n_vertices}, torch::CUDA(torch::kInt)); |
| 120 | + |
| 121 | + AT_DISPATCH_FLOATING_TYPES( |
| 122 | + ptcloud.scalar_type(), "cubic_feature_sampling_cuda", ([&] { |
| 123 | + cubic_feature_sampling_kernel<<<batch_size, get_n_threads(n_pts), 0, stream>>>( |
| 124 | + scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, |
| 125 | + ptcloud.data_ptr<scalar_t>(), cubic_features.data_ptr<scalar_t>(), |
| 126 | + point_features.data_ptr<scalar_t>(), grid_pt_indexes.data_ptr<int>()); |
| 127 | + })); |
| 128 | + |
| 129 | + cudaError_t err = cudaGetLastError(); |
| 130 | + if (err != cudaSuccess) |
| 131 | + { |
| 132 | + printf("Error in cubic_feature_sampling_kernel_wrapper: %s\n", cudaGetErrorString(err)); |
| 133 | + } |
| 134 | + return {point_features, grid_pt_indexes}; |
| 135 | +} |
| 136 | + |
| 137 | +template <typename scalar_t> |
| 138 | +__global__ void cubic_feature_sampling_grad_kernel(int scale, int neighborhood_size, int n_vertices, |
| 139 | + int n_pts, int n_cubic_channels, |
| 140 | + const scalar_t* __restrict__ grad_point_features, |
| 141 | + const int* __restrict__ grid_pt_indexes, |
| 142 | + scalar_t* __restrict__ grad_ptcloud, |
| 143 | + scalar_t* __restrict__ grad_cubic_features) |
| 144 | +{ |
| 145 | + int batch_index = blockIdx.x; |
| 146 | + int index = threadIdx.x; |
| 147 | + int stride = blockDim.x; |
| 148 | + int cub_scale = scale * scale * scale; |
| 149 | + |
| 150 | + grad_point_features += batch_index * n_pts * n_vertices * n_cubic_channels; |
| 151 | + grid_pt_indexes += batch_index * n_pts * n_vertices; |
| 152 | + grad_ptcloud += batch_index * n_pts * 3; |
| 153 | + grad_cubic_features += batch_index * n_cubic_channels * cub_scale; |
| 154 | + |
| 155 | + for (int i = index; i < n_pts; i += stride) |
| 156 | + { |
| 157 | + for (int j = 0; j < n_vertices; ++j) |
| 158 | + { |
| 159 | + int vertex_idx = grid_pt_indexes[i * n_vertices + j]; |
| 160 | + if (vertex_idx == -1) |
| 161 | + { |
| 162 | + continue; |
| 163 | + } |
| 164 | + for (int k = 0; k < n_cubic_channels; ++k) |
| 165 | + { |
| 166 | + int grad_idx = i * n_vertices * n_cubic_channels + j * n_cubic_channels + k; |
| 167 | + scalar_t grad_val = grad_point_features[grad_idx]; |
| 168 | + // Fix bugs: the gradients of ceil and floor functions are zeros. |
| 169 | + // Ref: https://github.com/tensorflow/tensorflow/issues/897 |
| 170 | + // atomicAdd(&(grad_ptcloud[i * 3 + 0]), grad_val); |
| 171 | + // atomicAdd(&(grad_ptcloud[i * 3 + 1]), grad_val); |
| 172 | + // atomicAdd(&(grad_ptcloud[i * 3 + 2]), grad_val); |
| 173 | + atomicAdd(&(grad_cubic_features[k * cub_scale + vertex_idx]), grad_val); |
| 174 | + } |
| 175 | + } |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +std::vector<torch::Tensor> |
| 180 | +cubic_feature_sampling_grad_kernel_wrapper(int scale, int neighborhood_size, |
| 181 | + torch::Tensor grad_point_features, |
| 182 | + torch::Tensor grid_pt_indexes, cudaStream_t stream) |
| 183 | +{ |
| 184 | + int batch_size = grad_point_features.size(0); |
| 185 | + int n_cubic_channels = grad_point_features.size(3); |
| 186 | + int n_pts = grid_pt_indexes.size(1); |
| 187 | + int n_vertices = std::pow(neighborhood_size * 2, 3); |
| 188 | + |
| 189 | + torch::Tensor grad_ptcloud = |
| 190 | + torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grad_point_features.scalar_type())); |
| 191 | + torch::Tensor grad_cubic_features = |
| 192 | + torch::zeros({batch_size, n_cubic_channels, scale, scale, scale}, |
| 193 | + torch::CUDA(grad_point_features.scalar_type())); |
| 194 | + |
| 195 | + AT_DISPATCH_FLOATING_TYPES( |
| 196 | + grad_point_features.scalar_type(), "cubic_feature_sampling_grad_cuda", ([&] { |
| 197 | + cubic_feature_sampling_grad_kernel<<<batch_size, get_n_threads(n_pts), 0, stream>>>( |
| 198 | + scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, |
| 199 | + grad_point_features.data_ptr<scalar_t>(), grid_pt_indexes.data_ptr<int>(), |
| 200 | + grad_ptcloud.data_ptr<scalar_t>(), grad_cubic_features.data_ptr<scalar_t>()); |
| 201 | + })); |
| 202 | + |
| 203 | + cudaError_t err = cudaGetLastError(); |
| 204 | + if (err != cudaSuccess) |
| 205 | + { |
| 206 | + printf("Error in cubic_feature_sampling_grad_kernel_wrapper: %s\n", |
| 207 | + cudaGetErrorString(err)); |
| 208 | + } |
| 209 | + return {grad_ptcloud, grad_cubic_features}; |
| 210 | +} |
0 commit comments