|
3 | 3 | from torch_scatter import scatter_add, scatter_max |
4 | 4 |
|
5 | 5 |
|
def scatter_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Softmax operation over all values in :attr:`src` tensor that share indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    # Softmax is only well-defined for floating point inputs.
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over '
                         'tensors with floating point data types.')

    # Subtract the per-group maximum before exponentiating — the standard
    # log-sum-exp trick to avoid overflow in exp().
    # NOTE(review): fill_value=0 means groups whose entries are all negative
    # are re-centered against 0 rather than their true max; still correct,
    # but numerically weaker there — confirm against the scatter_max version
    # in use.
    group_max, _ = scatter_max(src, index, dim=dim, fill_value=0)
    shifted = src - group_max.gather(dim, index)
    shifted_exp = shifted.exp()

    # Per-group partition function, broadcast back to each source element.
    exp_sums = scatter_add(shifted_exp, index, dim=dim)
    denominators = (exp_sums + eps).gather(dim, index)

    return shifted_exp / denominators
66 | 44 |
|
67 | 45 |
|
def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Log-softmax operation over all values in :attr:`src` tensor that share
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
        \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    # Log-softmax is only well-defined for floating point inputs.
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    # Re-center each element by its group's maximum (log-sum-exp trick),
    # so exp() below cannot overflow.
    # NOTE(review): fill_value=0 re-centers all-negative groups against 0
    # instead of their true max — correct but numerically weaker; confirm
    # against the scatter_max version in use.
    group_max, _ = scatter_max(src, index, dim=dim, fill_value=0)
    shifted = src - group_max.gather(dim, index)

    # log(sum_j exp(src_j - max)) per group, broadcast back per element.
    exp_sums = scatter_add(src=shifted.exp(), index=index, dim=dim)
    log_denominators = torch.log(exp_sums + eps).gather(dim, index)

    # log_softmax = (src - max) - log(sum exp(src - max))
    return shifted - log_denominators
0 commit comments