Skip to content

Commit 46a9c9a

Browse files
committed
added docs
1 parent 37d18b5 commit 46a9c9a

File tree

2 files changed

+151
-29
lines changed

2 files changed

+151
-29
lines changed

README.md

Lines changed: 148 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313

1414
--------------------------------------------------------------------------------
1515

16-
This package consists of a small extension library of optimized sparse matrix operations for the use in [PyTorch](http://pytorch.org/), which are missing and or lack autograd support in the main package.
16+
[PyTorch](http://pytorch.org/) (<= 0.4.1) completely lacks autograd support and operations such as sparse sparse matrix multiplication, but is heavily working on improvement (*cf.* [this issue](https://github.com/pytorch/pytorch/issues/9674)).
17+
In the meantime, this package consists of a small extension library of optimized sparse matrix operations with autograd support.
1718
This package currently consists of the following methods:
1819

19-
* **[Autograd Sparse Tensor Creation](#autograd-sparse-tensor-creation)**
20-
* **[Autograd Sparse Tensor Value Extraction](#autograd-sparse-tensor-value-extraction)**
20+
* **[Coalesce](#coalesce)**
21+
* **[Transpose](#transpose)**
22+
* **[Sparse Dense Matrix Multiplication](#sparse-dense-matrix-multiplication)**
2123
* **[Sparse Sparse Matrix Multiplication](#sparse-sparse-matrix-multiplication)**
2224

2325
All included operations work on varying data types and are implemented both for CPU and GPU.
26+
To avoid the hazzle of creating [`torch.sparse_coo_tensor`](https://pytorch.org/docs/stable/torch.html?highlight=sparse_coo_tensor#torch.sparse_coo_tensor), this package defines operations on sparse tensors by simply passing `index` and `value` tensors as arguments ([with same shapes as defined in PyTorch](https://pytorch.org/docs/stable/sparse.html)).
27+
Note that only `value` comes with autograd support, as `index` is discrete and therefore not differentiable.
2428

2529
## Installation
2630

@@ -45,60 +49,176 @@ pip install torch-scatter torch-sparse
4549

4650
If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_sparse/issues).
4751

48-
## Autograd Sparse Tensor Creation
52+
## Coalesce
4953

5054
```
51-
torch_sparse.sparse_coo_tensor(torch.LongTensor, torch.Tensor, torch.Size) -> torch.SparseTensor
55+
torch_sparse.coalesce(index, value, m, n, op="add", fill_value=0) -> (torch.LongTensor, torch.Tensor)
5256
```
5357

54-
Constructs a [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html) with autograd capabilities w.r.t. `value`.
58+
Row-wise sorts `value` and removes duplicate entries.
59+
Duplicate entries are removed by scattering them together.
60+
For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/pytorch_scatter) can be used.
61+
62+
### Parameters
63+
64+
* **index** *(LongTensor)* - The index tensor of sparse matrix.
65+
* **value** *(Tensor)* - The value tensor of sparse matrix.
66+
* **m** *(int)* - First dimension of sparse matrix.
67+
* **n** *(int)* - Second dimension of sparse matrix.
68+
* **op** *(string, optional)* - Scatter operation to use. (default: `"add"`)
69+
* **fill_value** *(int, optional)* - Initial fill value of scatter operation. (default: `0`)
70+
71+
### Returns
72+
73+
* **index** *(LongTensor)* - Coalesced index tensor of sparse matrix.
74+
* **value** *(Tensor)* - Coalesced value tensor of sparse matrix.
75+
76+
### Example
5577

5678
```python
57-
from torch_sparse import sparse_coo_tensor
79+
from torch_sparse import coalesce
80+
81+
index = torch.tensor([[1, 0, 1, 0, 2, 1],
82+
[0, 1, 1, 1, 0, 0]])
83+
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
5884

59-
i = torch.tensor([[0, 1, 1],
60-
[2, 0, 2]])
61-
v = torch.Tensor([3, 4, 5], requires_grad=True)
62-
A = sparse_coo_tensor(i, v, torch.Size([2,3]))
85+
index, value = coalesce(index, value, m=3, n=2)
6386
```
6487

65-
This method may become obsolete in future PyTorch releases (>= 0.4.1) as reported by this [issue](https://github.com/pytorch/pytorch/issues/9674).
88+
```
89+
print(index)
90+
tensor([[0, 1, 1, 2],
91+
[1, 0, 1, 0]])
92+
print(value)
93+
tensor([[6, 8], [7, 9], [3, 4], [5, 6]])
94+
```
6695

67-
## Autograd Sparse Tensor Value Extraction
96+
## Transpose
6897

6998
```
70-
torch_sparse.to_value(torch.SparseTensor) -> torch.Tensor
99+
torch_sparse.transpose(index, value, m, n) -> (torch.LongTensor, torch.Tensor)
71100
```
72101

73-
Wrapper method to support autograd on values of [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html).
102+
Transposes dimensions 0 and 1 of a sparse matrix.
103+
104+
### Parameters
105+
106+
* **index** *(LongTensor)* - The index tensor of sparse matrix.
107+
* **value** *(Tensor)* - The value tensor of sparse matrix.
108+
* **m** *(int)* - First dimension of sparse matrix.
109+
* **n** *(int)* - Second dimension of sparse matrix.
110+
111+
### Returns
112+
113+
* **index** *(LongTensor)* - Transposed index tensor of sparse matrix.
114+
* **value** *(Tensor)* - Transposed value tensor of sparse matrix.
115+
116+
### Example
74117

75118
```python
76-
from torch_sparse import to_value
119+
from torch_sparse import transpose
120+
121+
index = torch.tensor([[1, 0, 1, 0, 2, 1],
122+
[0, 1, 1, 1, 0, 0]])
123+
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
77124

78-
i = torch.tensor([[0, 1, 1],
79-
[2, 0, 2]])
80-
v = torch.Tensor([3, 4, 5], requires_grad=True)
81-
A = torch.sparse_coo_tensor(i, v, torch.Size([2,3]), requires_grad=True)
82-
v = to_value(A)
125+
index, value = transpose(index, value, m=3, n=2)
83126
```
84127

85-
This method may become obsolete in future PyTorch releases (>= 0.4.1) as reported by this [issue](https://github.com/pytorch/pytorch/issues/9674).
128+
```
129+
print(index)
130+
tensor([[0, 0, 1, 1],
131+
[1, 2, 0, 1]])
132+
print(value)
133+
tensor([[7, 9],
134+
[5, 6],
135+
[6, 8],
136+
[3, 4]])
137+
```
86138

87-
## Sparse Sparse Matrix Multiplication
139+
## Sparse Dense Matrix Multiplication
88140

89141
```
90-
torch_sparse.spspmm(torch.SparseTensor, torch.SparseTensor) -> torch.SparseTensor
142+
torch_sparse.spmm(index, value, m, matrix) -> torch.Tensor
91143
```
92144

93-
Sparse matrix product of two sparse tensors with autograd support.
145+
Matrix product of a sparse matrix with a dense matrix.
146+
147+
### Parameters
94148

149+
* **index** *(LongTensor)* - The index tensor of sparse matrix.
150+
* **value** *(Tensor)* - The value tensor of sparse matrix.
151+
* **m** *(int)* - First dimension of sparse matrix.
152+
* **matrix** *(int)* - Dense matrix.
153+
154+
### Returns
155+
156+
* **out** *(Tensor)* - Dense output matrix.
157+
158+
### Example
159+
160+
```python
161+
from torch_sparse import spmm
162+
163+
index = torch.tensor([[0, 0, 1, 2, 2],
164+
[0, 2, 1, 0, 1]])
165+
value = torch.tensor([1, 2, 4, 1, 3])
166+
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
167+
168+
out = spmm(index, value, 3, matrix)
169+
```
170+
171+
```
172+
print(out)
173+
tensor([[7, 16],
174+
[8, 20],
175+
[7, 19]])
95176
```
177+
178+
## Sparse Sparse Matrix Multiplication
179+
180+
```
181+
torch_sparse.spspmm(indexA, valueA, indexB, valueB, m, k, n) -> (torch.LongTensor, torch.Tensor)
182+
```
183+
184+
Matrix product of two sparse tensors.
185+
Both input sparse matrices need to be **coalesced**.
186+
187+
### Parameters
188+
189+
* **indexA** *(LongTensor)* - The index tensor of first sparse matrix.
190+
* **valueA** *(Tensor)* - The value tensor of first sparse matrix.
191+
* **indexB** *(LongTensor)* - The index tensor of second sparse matrix.
192+
* **valueB** *(Tensor)* - The value tensor of second sparse matrix.
193+
* **m** *(int)* - First dimension of first sparse matrix.
194+
* **k** *(int)* - Second dimension of first sparse matrix and first dimension of second sparse matrix.
195+
* **n** *(int)* - Second dimension of second sparse matrix.
196+
197+
### Returns
198+
199+
* **index** *(LongTensor)* - Output index tensor of sparse matrix.
200+
* **value** *(Tensor)* - Output value tensor of sparse matrix.
201+
202+
### Example
203+
204+
```python
96205
from torch_sparse import spspmm
97206

98-
A = torch.sparse_coo_tensor(..., requries_grad=True)
99-
B = torch.sparse_coo_tensor(..., requries_grad=True)
207+
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
208+
valueA = torch.tensor([1, 2, 3, 4, 5])
209+
210+
indexB = torch.tensor([[0, 2], [1, 0]])
211+
valueB = torch.tensor([2, 4])
100212

101-
C = spspmm(A, B)
213+
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
214+
```
215+
216+
```
217+
print(index)
218+
tensor([[0, 1, 2],
219+
[0, 1, 1]])
220+
print(value)
221+
tensor([8, 6, 8])
102222
```
103223

104224
## Running tests

torch_sparse/coalesce.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def coalesce(index, value, m, n, op='add', fill_value=0):
6-
"""Row-wise reorders and removes duplicate entries in sparse matrixx."""
6+
"""Row-wise reorders and removes duplicate entries in sparse matrix."""
77

88
row, col = index
99

@@ -16,5 +16,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
1616
if value is not None:
1717
op = getattr(torch_scatter, 'scatter_{}'.format(op))
1818
value = op(value, inv, 0, None, perm.size(0), fill_value)
19+
if isinstance(value, tuple):
20+
value = value[0]
1921

2022
return index, value

0 commit comments

Comments
 (0)