Skip to content

Commit 134b5c7

Browse files
committed
fix division for segment_coo_mean when input is of type long
1 parent 8ec6d0c commit 134b5c7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

csrc/cuda/segment_coo_cuda.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,10 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
273273
auto count = arg_out.value();
274274
for (int i = dim + 1; i < out.dim(); i++)
275275
count = count.unsqueeze(-1);
276-
out.div_(count);
276+
if (out.is_floating_point())
277+
out.true_divide_(count);
278+
else
279+
out.floor_divide_(count);
277280
}
278281
});
279282
});

0 commit comments

Comments
 (0)