22
33from torch_scatter import scatter_add , scatter_max
44
5+
56def scatter_log_softmax (src , index , dim = - 1 , dim_size = None ):
67 r"""
7- Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
8+ Numerically safe log-softmax of all values from
9+ the :attr:`src` tensor into :attr:`out` at the
810 indices specified in the :attr:`index` tensor along a given axis
911 :attr:`dim`. If multiple indices reference the same location, their
1012 **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
1113
1214 For one-dimensional tensors, the operation computes
1315
1416 .. math::
15- \mathrm{out}_i = softmax(\mathrm{src}_i) =
17+ \mathrm{out}_i = \mathrm{log\_softmax}(\mathrm{src}_i) =
1618 \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j)
1719
1820 where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that
@@ -42,9 +44,12 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
4244 :rtype: :class:`Tensor`
4345 """
4446 if not torch .is_floating_point (src ):
45- raise ValueError ('log_softmax can be computed only over tensors with floating point data types.' )
47+ raise ValueError ('log_softmax can be computed only over '
48+ 'tensors with floating point data types.' )
4649
47- max_value_per_index , _ = scatter_max (src , index , dim = dim , dim_size = dim_size )
50+ max_value_per_index , _ = scatter_max (src , index ,
51+ dim = dim ,
52+ dim_size = dim_size )
4853 max_per_src_element = max_value_per_index .gather (dim , index )
4954
5055 recentered_scores = src - max_per_src_element
@@ -62,15 +67,16 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None):
6267
6368def scatter_softmax (src , index , dim = - 1 , dim_size = None , epsilon = 1e-16 ):
6469 r"""
65- Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the
70+ Numerically safe softmax of all values from
71+ the :attr:`src` tensor into :attr:`out` at the
6672 indices specified in the :attr:`index` tensor along a given axis
6773 :attr:`dim`. If multiple indices reference the same location, their
6874 **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
6975
7076 For one-dimensional tensors, the operation computes
7177
7278 .. math::
73- \mathrm{out}_i = softmax(\mathrm{src}_i) =
79+ \mathrm{out}_i = \mathrm{softmax}(\mathrm{src}_i) =
7480 \frac{\exp(\mathrm{src}_i)}{\sum_j \exp ( \mathrm{src}_j)}
7581
7682 where the summation :math:`\sum_j` is over :math:`j` such that
@@ -100,9 +106,12 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16):
100106 :rtype: :class:`Tensor`
101107 """
102108 if not torch .is_floating_point (src ):
103- raise ValueError ('softmax can be computed only over tensors with floating point data types.' )
109+ raise ValueError ('softmax can be computed only over '
110+ 'tensors with floating point data types.' )
104111
105- max_value_per_index , _ = scatter_max (src , index , dim = dim , dim_size = dim_size )
112+ max_value_per_index , _ = scatter_max (src , index ,
113+ dim = dim ,
114+ dim_size = dim_size )
106115 max_per_src_element = max_value_per_index .gather (dim , index )
107116
108117 recentered_scores = src - max_per_src_element
0 commit comments