File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
// Number of thread blocks needed to cover N elements with THREADS threads per
// block, i.e. ceil(N / THREADS) in integer arithmetic. The argument and the
// whole expansion are parenthesized so the macro stays correct when N is a
// compound expression or when the result is used inside a larger expression.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
77
/**
 * Flags the first element of every run of equal adjacent values in `src`.
 *
 * For each position i in [0, numel), sets mask[i] = true when i == 0 or
 * src[i] != src[i - 1]. Other positions are never written by this kernel, so
 * their value is whatever `mask` held before the launch.
 *
 * @param src    Input values, read only; compared element-wise with `!=`.
 * @param mask   Output flags, one per element of `src`.
 * @param numel  Number of elements in both `src` and `mask`.
 */
template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
                                   size_t numel) {
  // Widen to size_t before multiplying: the built-in index variables are
  // 32-bit unsigned, so the product could otherwise wrap on large launches.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  // Unsigned loop index matches the type of `numel` (no signed/unsigned
  // comparison); `i == 0` is checked first so `src[i - 1]` is never read
  // out of bounds.
  for (size_t i = index; i < numel; i += stride) {
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = true;
    }
  }
}
@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
2222 at::Tensor perm;
2323 std::tie (src, perm) = src.sort ();
2424
25- auto mask = at::zeros (src.numel (), src.options ().dtype (at::kByte ));
25+ auto mask = at::zeros (src.numel (), src.options ().dtype (at::kBool ));
2626 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " grid_cuda_kernel" , [&] {
2727 unique_cuda_kernel<scalar_t ><<<BLOCKS(src.numel()), THREADS>>> (
28- src.DATA_PTR <scalar_t >(), mask.DATA_PTR <uint8_t >(), src.numel ());
28+ src.DATA_PTR <scalar_t >(), mask.DATA_PTR <bool >(), src.numel ());
2929 });
3030
3131 src = src.masked_select (mask);
You can’t perform that action at this time.
0 commit comments