Skip to content

Commit 66bcc36

Browse files
committed
fix pytorch 1.8.0 compatibility
1 parent 1f49a3a commit 66bcc36

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

csrc/cuda/segment_coo_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
277277
for (int i = dim + 1; i < out.dim(); i++)
278278
count = count.unsqueeze(-1);
279279
if (out.is_floating_point())
280-
out.div_(count);
280+
out.true_divide_(count);
281281
else
282-
out.div_(count, "floor");
282+
out.floor_divide_(count);
283283
}
284284
});
285285
});

csrc/scatter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
130130
count.masked_fill_(count < 1, 1);
131131
count = broadcast(count, out, dim);
132132
if (out.is_floating_point())
133-
out.div_(count);
133+
out.true_divide_(count);
134134
else
135-
out.div_(count, "floor");
135+
out.floor_divide_(count);
136136

137137
ctx->save_for_backward({index, count});
138138
if (optional_out.has_value())

torch_scatter/scatter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
5252
count = scatter_sum(ones, index, index_dim, None, dim_size)
5353
count[count < 1] = 1
5454
count = broadcast(count, out, dim)
55-
rounding_mode = None if torch.is_floating_point(out) else 'floor'
56-
out.div_(count, rounding_mode=rounding_mode)
55+
if out.is_floating_point():
56+
out.true_divide_(count)
57+
else:
58+
out.floor_divide_(count)
5759
return out
5860

5961

0 commit comments

Comments
 (0)