Skip to content

Commit bf1f101

Browse files
committed
use scatter add pytorch implementation
1 parent 1006514 commit bf1f101

File tree

7 files changed

+29
-9
lines changed

7 files changed

+29
-9
lines changed

benchmark/scatter_segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def dense2(x):
217217
parser.add_argument('--with_backward', action='store_true')
218218
parser.add_argument('--device', type=str, default='cuda')
219219
args = parser.parse_args()
220-
iters = 1 if args.device == 'cpu' else 50
220+
iters = 1 if args.device == 'cpu' else 20
221221
sizes = [1, 16, 32, 64, 128, 256, 512]
222222
sizes = sizes[:3] if args.device == 'cpu' else sizes
223223

test/test_zero_tensors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from torch_scatter import scatter
3+
4+
5+
def test_zero_elements():
6+
x = torch.randn(0, 16)
7+
index = torch.tensor([]).view(0, 16)
8+
print(x)
9+
print(index)
10+
11+
scatter(x, index, dim=0, dim_size=0, reduce="add")

torch_scatter/composite/logsumexp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch_scatter import scatter_sum, scatter_max
55

6-
from .utils import broadcast
6+
from torch_scatter.utils import broadcast
77

88

99
@torch.jit.script

torch_scatter/composite/softmax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22

33
from torch_scatter import scatter_sum, scatter_max
4-
5-
from .utils import broadcast
4+
from torch_scatter.utils import broadcast
65

76

87
@torch.jit.script

torch_scatter/composite/std.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import torch
44
from torch_scatter import scatter_sum
5-
6-
from .utils import broadcast
5+
from torch_scatter.utils import broadcast
76

87

98
@torch.jit.script

torch_scatter/scatter.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import torch
66

7+
from .utils import broadcast
8+
79
try:
810
torch.ops.load_library(
911
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
@@ -23,7 +25,6 @@ def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
2325
raise ImportError
2426
return src, index
2527

26-
torch.ops.torch_scatter.scatter_sum = scatter_placeholder
2728
torch.ops.torch_scatter.scatter_mean = scatter_placeholder
2829
torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
2930
torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
@@ -33,14 +34,24 @@ def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
3334
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3435
out: Optional[torch.Tensor] = None,
3536
dim_size: Optional[int] = None) -> torch.Tensor:
36-
return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
37+
index = broadcast(index, src, dim)
38+
if out is None:
39+
size = src.size()
40+
if dim_size is None:
41+
size[dim] = int(index.max()) + 1
42+
else:
43+
size[dim] = dim_size
44+
out = src.new_zeros(size)
45+
return out.scatter_add_(dim, index, src)
46+
else:
47+
return out.scatter_add_(dim, index, src)
3748

3849

3950
@torch.jit.script
4051
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
4152
out: Optional[torch.Tensor] = None,
4253
dim_size: Optional[int] = None) -> torch.Tensor:
43-
return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
54+
return scatter_sum(src, index, dim, out, dim_size)
4455

4556

4657
@torch.jit.script
File renamed without changes.

0 commit comments

Comments
 (0)