@@ -50,7 +50,7 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
5050 .. math::
5151 \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
5252
53- where :math:`\sum ` is over :math:`j` such that
53+ where :math:`\sum_j ` is over :math:`j` such that
5454 :math:`\mathrm{index}_j = i`.
5555
5656 Args:
@@ -75,17 +75,19 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
7575 .. testcode::
7676
7777 from torch_scatter import scatter_add
78+
7879 src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
7980 index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
8081 out = src.new_zeros((2, 6))
82+
8183 out = scatter_add(src, index, out=out)
84+
8285 print(out)
8386
8487 .. testoutput::
8588
86- 0 0 4 3 3 0
87- 2 4 4 0 0 0
88- [torch.FloatTensor of size 2x6]
89+ tensor([[ 0, 0, 4, 3, 3, 0],
90+ [ 2, 4, 4, 0, 0, 0]])
8991 """
9092 src , out , index , dim = gen (src , index , dim , out , dim_size , fill_value )
9193 return ScatterAdd .apply (out , src , index , dim )
0 commit comments