Skip to content

Commit 105a60b

Browse files
committed
fix spmm for highly sparse matrices
1 parent fca6819 commit 105a60b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

csrc/cpu/spmm_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
5656
value_data = optional_value.value().data_ptr<scalar_t>();
5757
}
5858

59-
int64_t grain_size =
60-
at::internal::GRAIN_SIZE / (K * (col.numel() / M));
59+
int64_t grain_size = at::internal::GRAIN_SIZE /
60+
(K * std::max(col.numel() / M, (int64_t)1));
6161
at::parallel_for(
6262
0, B * M, grain_size, [&](int64_t begin, int64_t end) {
6363
scalar_t val;

0 commit comments

Comments
 (0)