Skip to content

Commit 5e69a34

Browse files
committed
only execute kernels if size(dim) > 0
1 parent a8e3e28 commit 5e69a34

File tree

4 files changed

+8
-0
lines changed

4 files changed

+8
-0
lines changed

torch_scatter/div.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,6 @@ def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
8989
[0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])
9090
"""
9191
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
92+
if src.size(dim) == 0: # pragma: no cover
93+
return out
9294
return ScatterDiv.apply(out, src, index, dim)

torch_scatter/max.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,6 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
9696
[ 1, 4, 3, -1, -1, -1]])
9797
"""
9898
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
99+
if src.size(dim) == 0: # pragma: no cover
100+
return out
99101
return ScatterMax.apply(out, src, index, dim)

torch_scatter/min.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,6 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
9696
[ 1, 4, 3, -1, -1, -1]])
9797
"""
9898
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
99+
if src.size(dim) == 0: # pragma: no cover
100+
return out
99101
return ScatterMin.apply(out, src, index, dim)

torch_scatter/mul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,6 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
8888
[6, 4, 8, 1, 1, 1]])
8989
"""
9090
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
91+
if src.size(dim) == 0: # pragma: no cover
92+
return out
9193
return ScatterMul.apply(out, src, index, dim)

0 commit comments

Comments
 (0)