Skip to content

Commit 4c8e529

Browse files
committed
fix set_diag for nnz=0
1 parent 6456fb4 commit 4c8e529

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

csrc/cuda/diag_cuda.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
5454
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
5555
auto mask_data = mask.data_ptr<bool>();
5656

57+
if (E == 0)
58+
return mask;
59+
5760
auto stream = at::cuda::getCurrentCUDAStream();
5861
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
5962
row_data, col_data, mask_data, N, k, num_diag, E);

0 commit comments

Comments
 (0)