Skip to content

Commit b56c235

Browse files
committed
use bool mask
1 parent 1c4fdfe commit b56c235

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cuda/unique_kernel.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
// Number of thread blocks needed to cover N elements with THREADS threads per
// block (ceiling division). The argument and the whole expansion are
// parenthesized so the macro is safe inside larger expressions and with
// arguments that contain low-precedence operators.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
77

88
// Marks the first element of every run of equal consecutive values in `src`.
//
// For each position i in [0, numel), mask[i] is set to true when i == 0 or
// src[i] differs from src[i - 1]. Positions that do not start a run are left
// untouched, so `mask` must already be zero-initialized by the caller.
//
// Parameters:
//   src   - input values, read-only; only adjacent elements are compared.
//   mask  - output flags, one per element of `src`.
//   numel - number of elements in `src` and `mask`.
//
// Launch layout: 1D grid of 1D blocks. The kernel uses a grid-stride loop, so
// any grid size is valid, including one smaller than numel / blockDim.x.
template <typename scalar_t>
__global__ void unique_cuda_kernel(const scalar_t *__restrict__ src, bool *mask,
                                   size_t numel) {
  // Widen to size_t *before* multiplying: blockIdx.x, blockDim.x and gridDim.x
  // are 32-bit unsigned, so the product would otherwise wrap for huge launches.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  // Unsigned index matches the type of `numel`, avoiding a signed/unsigned
  // comparison in the loop condition.
  for (size_t i = index; i < numel; i += stride) {
    // `i == 0` is checked first so that src[i - 1] is never read out of bounds.
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = true;
    }
  }
}
@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
2222
at::Tensor perm;
2323
std::tie(src, perm) = src.sort();
2424

25-
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
25+
auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));
2626
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
2727
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
28-
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<uint8_t>(), src.numel());
28+
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
2929
});
3030

3131
src = src.masked_select(mask);

0 commit comments

Comments
 (0)