
Commit d824c8b

committed
segment documentation
1 parent 9b365d3 commit d824c8b

File tree

8 files changed, +197 -19 lines changed


docs/source/_figures/segment_coo.svg

Lines changed: 19 additions & 11 deletions

docs/source/_figures/template.tex

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 \draw[edge] (index\i) -- (input\i);
 }
 
-\node[title] at (-0.8, 0.0) {output};
+\node[title] at (-0.8, 0.0) {out};
 \foreach \i in {0,...,\numberOutputs} {
 \pgfmathparse{\outputs[\i]}\let\out\pgfmathresult
 \pgfmathparse{\colors[\i]}\let\co\pgfmathresult

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+Segment COO
+===========
+
+.. automodule:: torch_scatter
+    :noindex:
+
+.. autofunction:: segment_coo

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+Segment CSR
+===========
+
+.. automodule:: torch_scatter
+    :noindex:
+
+.. autofunction:: segment_csr

docs/source/index.rst

Lines changed: 3 additions & 2 deletions
@@ -3,8 +3,9 @@
 PyTorch Scatter Documentation
 =============================
 
-This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
-Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
+This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
+Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
+Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
 
 All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
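For illustration only (not part of the diff): a minimal sketch of the scatter/segment distinction described above, assuming the scatter_add and segment_coo signatures shown later in this commit. The permutation simply pre-sorts the data into group order.

    import torch
    from torch_scatter import scatter_add, segment_coo

    src = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # scatter_add accepts the "group-index" tensor in any order.
    index = torch.tensor([1, 0, 1, 0])
    out_scatter = scatter_add(src, index)              # tensor([6., 4.])

    # segment_coo expects the "group-index" tensor to be sorted,
    # so values are permuted into group order first.
    perm = torch.tensor([1, 3, 0, 2])
    out_segment = segment_coo(src[perm], index[perm])  # tensor([6., 4.])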

torch_scatter/add.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
     specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
     each value in :attr:`src`, its output index is specified by its index in
-    :attr:`input` for dimensions outside of :attr:`dim` and by the
+    :attr:`src` for dimensions outside of :attr:`dim` and by the
     corresponding value in :attr:`index` for dimension :attr:`dim`. If
     multiple indices reference the same location, their **contributions add**.
 
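For illustration only (not part of the diff): a small example of the behavior this docstring describes, assuming the scatter_add signature in the hunk header above.

    import torch
    from torch_scatter import scatter_add

    src = torch.tensor([[2, 0, 1, 4, 3],
                        [0, 2, 1, 3, 4]])
    index = torch.tensor([[4, 5, 4, 2, 3],
                          [0, 0, 2, 2, 1]])

    out = scatter_add(src, index)
    # Each value keeps its row (a dimension outside `dim`) and is added at the
    # column given by `index`; colliding indices are summed:
    # tensor([[0, 0, 4, 3, 3, 0],
    #         [2, 4, 4, 0, 0, 0]])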

torch_scatter/helpers.py

Lines changed: 2 additions & 2 deletions
@@ -1,14 +1,14 @@
 import torch
 
 
-def min_value(dtype):
+def min_value(dtype): # pragma: no cover
     try:
         return torch.finfo(dtype).min
     except TypeError:
         return torch.iinfo(dtype).min
 
 
-def max_value(dtype):
+def max_value(dtype): # pragma: no cover
     try:
         return torch.finfo(dtype).max
     except TypeError:
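For illustration only (not part of the diff): what these helpers return for a floating-point versus an integer dtype, assuming the module is importable as torch_scatter.helpers.

    import torch
    from torch_scatter.helpers import min_value, max_value

    # torch.finfo handles floating-point dtypes; integer dtypes raise
    # TypeError and fall back to torch.iinfo.
    print(min_value(torch.float32))  # -3.4028234663852886e+38
    print(max_value(torch.int64))    # 9223372036854775807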

torch_scatter/segment.py

Lines changed: 157 additions & 2 deletions
@@ -112,9 +112,164 @@ def backward(ctx, grad_out, *args):
         return grad_src, None, None, None
 
 
-def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
+def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
+    r"""
+    |
+
+    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
+            master/docs/source/_figures/segment_coo.svg?sanitize=true
+        :align: center
+        :width: 400px
+
+    |
+
+    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
+    indices specified in the :attr:`index` tensor along the last dimension of
+    :attr:`index`.
+    For each value in :attr:`src`, its output index is specified by its index
+    in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
+    corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
+    The applied reduction is defined via the :attr:`reduce` argument.
+
+    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
+    :math:`m`-dimensional tensors with
+    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
+    :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
+    :math:`n`-dimensional tensor with size
+    :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
+    Moreover, the values of :attr:`index` must be between :math:`0` and
+    :math:`y - 1` in ascending order.
+    The :attr:`index` tensor supports broadcasting in case its dimensions do
+    not match with :attr:`src`.
+    For one-dimensional tensors with :obj:`reduce="add"`, the operation
+    computes
+
+    .. math::
+        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    In contrast to :meth:`scatter`, this method expects values in :attr:`index`
+    **to be sorted** along dimension :obj:`index.dim() - 1`.
+    Due to the use of sorted indices, :meth:`segment_coo` is usually faster
+    than the more general :meth:`scatter` operation.
+
+    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
+    second tensor representing the :obj:`argmin` and :obj:`argmax`,
+    respectively.
+
+    .. note::
+
+        This operation is implemented via atomic operations on the GPU and is
+        therefore **non-deterministic** since the order of parallel operations
+        to the same value is undetermined.
+        For floating-point variables, this results in a source of variance in
+        the result.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The sorted indices of elements to segment.
+            The number of dimensions of :attr:`index` needs to be less than or
+            equal to :attr:`src`.
+        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
+        dim_size (int, optional): If :attr:`out` is not given, automatically
+            create output with size :attr:`dim_size` at dimension
+            :obj:`index.dim() - 1`.
+            If :attr:`dim_size` is not given, a minimal sized output tensor
+            according to :obj:`index.max() + 1` is returned.
+            (default: :obj:`None`)
+        reduce (string, optional): The reduce operation (:obj:`"add"`,
+            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
+            (default: :obj:`"add"`)
+
+    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
+
+    .. code-block:: python
+
+        from torch_scatter import segment_coo
+
+        src = torch.randn(10, 6, 64)
+        index = torch.tensor([0, 0, 1, 1, 1, 2])
+        index = index.view(1, -1)  # Broadcasting in the first and last dim.
+
+        out = segment_coo(src, index, reduce="add")
+
+        print(out.size())
+
+    .. code-block::
+
+        torch.Size([10, 3, 64])
+    """
     return SegmentCOO.apply(src, index, out, dim_size, reduce)
 
 
-def segment_csr(src, indptr, out=None, reduce='add'):
+def segment_csr(src, indptr, out=None, reduce="add"):
+    r"""
+    Reduces all values from the :attr:`src` tensor into :attr:`out` within the
+    ranges specified in the :attr:`indptr` tensor along the last dimension of
+    :attr:`indptr`.
+    For each value in :attr:`src`, its output index is specified by its index
+    in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
+    corresponding range index in :attr:`indptr` for dimension
+    :obj:`indptr.dim() - 1`.
+    The applied reduction is defined via the :attr:`reduce` argument.
+
+    Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
+    :math:`m`-dimensional tensors with
+    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
+    :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
+    :math:`n`-dimensional tensor with size
+    :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
+    Moreover, the values of :attr:`indptr` must be between :math:`0` and
+    :math:`x_m` in ascending order.
+    The :attr:`indptr` tensor supports broadcasting in case its dimensions do
+    not match with :attr:`src`.
+    For one-dimensional tensors with :obj:`reduce="add"`, the operation
+    computes
+
+    .. math::
+        \mathrm{out}_i =
+        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.
+
+    Due to the use of index pointers, :meth:`segment_csr` is the fastest
+    method to apply for grouped reductions.
+
+    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
+    second tensor representing the :obj:`argmin` and :obj:`argmax`,
+    respectively.
+
+    .. note::
+
+        In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
+        operation is **fully-deterministic**.
+
+    Args:
+        src (Tensor): The source tensor.
+        indptr (LongTensor): The index pointers between elements to segment.
+            The number of dimensions of :attr:`indptr` needs to be less than or
+            equal to :attr:`src`.
+        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
+        reduce (string, optional): The reduce operation (:obj:`"add"`,
+            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
+            (default: :obj:`"add"`)
+
+    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
+
+    .. code-block:: python
+
+        from torch_scatter import segment_csr
+
+        src = torch.randn(10, 6, 64)
+        indptr = torch.tensor([0, 2, 5, 6])
+        indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.
+
+        out = segment_csr(src, indptr, reduce="add")
+
+        print(out.size())
+
+    .. code-block::
+
+        torch.Size([10, 3, 64])
+    """
     return SegmentCSR.apply(src, indptr, out, reduce)
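For illustration only (not part of the diff): a small sketch of the behavior these docstrings describe, including the second tensor returned for "min"/"max" reductions. The exact form of that second return value is assumed here from the docstrings above.

    import torch
    from torch_scatter import segment_coo, segment_csr

    src = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0, 6.0])
    index = torch.tensor([0, 0, 1, 1, 1, 2])    # sorted "group-index" tensor
    indptr = torch.tensor([0, 2, 5, 6])         # CSR-style segment boundaries

    out, argmax = segment_coo(src, index, reduce="max")
    # out    -> tensor([3., 5., 6.])
    # argmax -> per-segment positions of the maxima (second return value)

    out_csr = segment_csr(src, indptr, reduce="add")
    # out_csr -> tensor([4., 11., 6.]), i.e.
    # out_i = sum_{j = indptr[i]}^{indptr[i+1]-1} src_j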
