@@ -44,9 +44,9 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
4444 Minimizes all values from the :attr:`src` tensor into :attr:`out` at the
4545 indices specified in the :attr:`index` tensor along a given axis
4646 :attr:`dim`.If multiple indices reference the same location, their
47- **contributions maximize ** (`cf.` :meth:`~torch_scatter.scatter_add`).
47+ **contributions minimize ** (`cf.` :meth:`~torch_scatter.scatter_add`).
4848 The second return tensor contains index location in :attr:`src` of each
49- minimum value (known as argmax ).
49+ minimum value (known as argmin ).
5050
5151 For one-dimensional tensors, the operation computes
5252
@@ -83,10 +83,10 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
8383 index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
8484 out = src.new_zeros((2, 6))
8585
86- out, argmax = scatter_min(src, index, out=out)
86+ out, argmin = scatter_min(src, index, out=out)
8787
8888 print(out)
89- print(argmax )
89+ print(argmin )
9090
9191 .. testoutput::
9292
@@ -97,5 +97,5 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
9797 """
9898 src , out , index , dim = gen (src , index , dim , out , dim_size , fill_value )
9999 if src .size (dim ) == 0 : # pragma: no cover
100- return out
100+ return out , index . new_full ( out . size (), - 1 )
101101 return ScatterMin .apply (out , src , index , dim )
0 commit comments