@@ -112,9 +112,164 @@ def backward(ctx, grad_out, *args):
112112 return grad_src , None , None , None
113113
114114
115- def segment_coo (src , index , out = None , dim_size = None , reduce = 'add' ):
115+ def segment_coo (src , index , out = None , dim_size = None , reduce = "add" ):
116+ r"""
117+ |
118+
119+ .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
120+ master/docs/source/_figures/segment_coo.svg?sanitize=true
121+ :align: center
122+ :width: 400px
123+
124+ |
125+
126+ Reduces all values from the :attr:`src` tensor into :attr:`out` at the
127+ indices specified in the :attr:`index` tensor along the last dimension of
128+ :attr:`index`.
129+ For each value in :attr:`src`, its output index is specified by its index
130+ in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
131+ corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
132+ The applied reduction is defined via the :attr:`reduce` argument.
133+
134+ Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
135+ :math:`m`-dimensional tensors with
136+ size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
137+ :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
138+ :math:`n`-dimensional tensor with size
139+ :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
140+ Moreover, the values of :attr:`index` must be between :math:`0` and
141+ :math:`y - 1` in ascending order.
142+ The :attr:`index` tensor supports broadcasting in case its dimensions do
143+ not match with :attr:`src`.
144+ For one-dimensional tensors with :obj:`reduce="add"`, the operation
145+ computes
146+
147+ .. math::
148+ \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
149+
150+ where :math:`\sum_j` is over :math:`j` such that
151+ :math:`\mathrm{index}_j = i`.
152+
153+ In contrast to :meth:`scatter`, this method expects values in :attr:`index`
154+ **to be sorted** along dimension :obj:`index.dim() - 1`.
155+ Due to the use of sorted indices, :meth:`segment_coo` is usually faster
156+ than the more general :meth:`scatter` operation.
157+
158+ For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
159+ second tensor representing the :obj:`argmin` and :obj:`argmax`,
160+ respectively.
161+
162+ .. note::
163+
164+ This operation is implemented via atomic operations on the GPU and is
165+ therefore **non-deterministic** since the order of parallel operations
166+ to the same value is undetermined.
167+ For floating-point variables, this results in a source of variance in
168+ the result.
169+
170+ Args:
171+ src (Tensor): The source tensor.
172+ index (LongTensor): The sorted indices of elements to segment.
173+ The number of dimensions of :attr:`index` needs to be less than or
174+ equal to :attr:`src`.
175+ out (Tensor, optional): The destination tensor. (default: :obj:`None`)
176+ dim_size (int, optional): If :attr:`out` is not given, automatically
177+ create output with size :attr:`dim_size` at dimension
178+ :obj:`index.dim() - 1`.
179+ If :attr:`dim_size` is not given, a minimal sized output tensor
180+ according to :obj:`index.max() + 1` is returned.
181+ (default: :obj:`None`)
182+ reduce (string, optional): The reduce operation (:obj:`"add"`,
183+ :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
184+ (default: :obj:`"add"`)
185+
186+ :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
187+
188+ .. code-block:: python
189+
190+ from torch_scatter import segment_coo
191+
192+ src = torch.randn(10, 6, 64)
193+ index = torch.tensor([0, 0, 1, 1, 1, 2])
194+ index = index.view(1, -1) # Broadcasting in the first and last dim.
195+
196+ out = segment_coo(src, index, reduce="add")
197+
198+ print(out.size())
199+
200+ .. code-block::
201+
202+ torch.Size([10, 3, 64])
203+ """
116204 return SegmentCOO .apply (src , index , out , dim_size , reduce )
117205
118206
119- def segment_csr (src , indptr , out = None , reduce = 'add' ):
207+ def segment_csr (src , indptr , out = None , reduce = "add" ):
208+ r"""
209+ Reduces all values from the :attr:`src` tensor into :attr:`out` within the
210+ ranges specified in the :attr:`indptr` tensor along the last dimension of
211+ :attr:`indptr`.
212+ For each value in :attr:`src`, its output index is specified by its index
213+ in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
214+ corresponding range index in :attr:`indptr` for dimension
215+ :obj:`indptr.dim() - 1`.
216+ The applied reduction is defined via the :attr:`reduce` argument.
217+
218+ Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
219+ :math:`m`-dimensional tensors with
220+ size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
221+ :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
222+ :math:`n`-dimensional tensor with size
223+ :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
224+ Moreover, the values of :attr:`indptr` must be between :math:`0` and
225+ :math:`x_m` in ascending order.
226+ The :attr:`indptr` tensor supports broadcasting in case its dimensions do
227+ not match with :attr:`src`.
228+ For one-dimensional tensors with :obj:`reduce="add"`, the operation
229+ computes
230+
231+ .. math::
232+ \mathrm{out}_i =
233+ \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j.
234+
235+ Due to the use of index pointers, :meth:`segment_csr` is the fastest
236+ method to apply for grouped reductions.
237+
238+ For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
239+ second tensor representing the :obj:`argmin` and :obj:`argmax`,
240+ respectively.
241+
242+ .. note::
243+
244+ In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
245+ operation is **fully-deterministic**.
246+
247+ Args:
248+ src (Tensor): The source tensor.
249+ indptr (LongTensor): The index pointers between elements to segment.
250+ The number of dimensions of :attr:`index` needs to be less than or
251+ equal to :attr:`src`.
252+ out (Tensor, optional): The destination tensor. (default: :obj:`None`)
253+ reduce (string, optional): The reduce operation (:obj:`"add"`,
254+ :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
255+ (default: :obj:`"add"`)
256+
257+ :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
258+
259+ .. code-block:: python
260+
261+ from torch_scatter import segment_csr
262+
263+ src = torch.randn(10, 6, 64)
264+ indptr = torch.tensor([0, 2, 5, 6])
265+ indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
266+
267+ out = segment_csr(src, indptr, reduce="add")
268+
269+ print(out.size())
270+
271+ .. code-block::
272+
273+ torch.Size([10, 3, 64])
274+ """
120275 return SegmentCSR .apply (src , indptr , out , reduce )
0 commit comments