Skip to content

Commit c01f9ba

Browse files
authored
Merge pull request #105 from rusty1s/traceable
[WIP] tracebale functions
2 parents 2520670 + 02a47c4 commit c01f9ba

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+3048
-3507
lines changed

.coveragerc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,5 @@ source=torch_scatter
33
[report]
44
exclude_lines =
55
pragma: no cover
6-
cuda
7-
forward
8-
backward
9-
apply
6+
torch.jit.script
107
raise
11-
min_value
12-
max_value

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ install:
3939
- pip install codecov
4040
- pip install sphinx
4141
- pip install sphinx_rtd_theme
42+
- pip install sphinx-autodoc-typehints
4243
script:
4344
- python -c "import torch; print(torch.__version__)"
4445
- pycodestyle .

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2019 Matthias Fey <[email protected]>
1+
Copyright (c) 2020 Matthias Fey <[email protected]>
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,27 @@
2222

2323
**[Documentation](https://pytorch-scatter.readthedocs.io)**
2424

25-
This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
26-
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
27-
The package consists of the following operations:
25+
This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
26+
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
27+
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.
2828

29-
* [**Scatter Add**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
30-
* [**Scatter Sub**](https://pytorch-scatter.readthedocs.io/en/latest/functions/sub.html)
31-
* [**Scatter Mul**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mul.html)
32-
* [**Scatter Div**](https://pytorch-scatter.readthedocs.io/en/latest/functions/div.html)
33-
* [**Scatter Mean**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mean.html)
34-
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
35-
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
36-
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
37-
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
29+
The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:
3830

39-
In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
31+
* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment.html) based on arbitrary indices
32+
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
33+
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
4034

41-
* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
42-
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
35+
In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: :`scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
4336

44-
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
37+
All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
4538

4639
## Installation
4740

48-
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
41+
Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
4942

5043
```
5144
$ python -c "import torch; print(torch.__version__)"
52-
>>> 1.1.0
45+
>>> 1.3.0
5346
5447
$ echo $PATH
5548
>>> /usr/local/cuda/bin:...
@@ -81,17 +74,17 @@ from torch_scatter import scatter_max
8174
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
8275
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
8376

84-
out, argmax = scatter_max(src, index, fill_value=0)
77+
out, argmax = scatter_max(src, index, dim=-1)
8578
```
8679

8780
```
8881
print(out)
89-
tensor([[ 0, 0, 4, 3, 2, 0],
90-
[ 2, 4, 3, 0, 0, 0]])
82+
tensor([[0, 0, 4, 3, 2, 0],
83+
[2, 4, 3, 0, 0, 0]])
9184
9285
print(argmax)
93-
tensor([[-1, -1, 3, 4, 0, 1]
94-
[ 1, 4, 3, -1, -1, -1]])
86+
tensor([[5, 5, 3, 4, 0, 1]
87+
[1, 4, 3, 5, 5, 5]])
9588
```
9689

9790
## Running tests

benchmark/scatter_segment.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
import torch
88
from scipy.io import loadmat
99

10-
import torch_scatter
11-
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
12-
from torch_scatter import segment_coo, segment_csr
10+
from torch_scatter import scatter, segment_coo, segment_csr
1311

1412
short_rows = [
1513
('DIMACS10', 'citationCiteseer'),
@@ -47,34 +45,30 @@ def correctness(dataset):
4745
x = torch.randn((row.size(0), size), device=args.device)
4846
x = x.squeeze(-1) if size == 1 else x
4947

50-
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
48+
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
5149
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
5250
out3 = segment_csr(x, rowptr, reduce='add')
5351

5452
assert torch.allclose(out1, out2, atol=1e-4)
5553
assert torch.allclose(out1, out3, atol=1e-4)
5654

57-
out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
55+
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
5856
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
5957
out3 = segment_csr(x, rowptr, reduce='mean')
6058

6159
assert torch.allclose(out1, out2, atol=1e-4)
6260
assert torch.allclose(out1, out3, atol=1e-4)
6361

64-
x = x.abs_().mul_(-1)
65-
66-
out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
67-
out2, _ = segment_coo(x, row, reduce='min')
68-
out3, _ = segment_csr(x, rowptr, reduce='min')
62+
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
63+
out2 = segment_coo(x, row, reduce='min')
64+
out3 = segment_csr(x, rowptr, reduce='min')
6965

7066
assert torch.allclose(out1, out2, atol=1e-4)
7167
assert torch.allclose(out1, out3, atol=1e-4)
7268

73-
x = x.abs_()
74-
75-
out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
76-
out2, _ = segment_coo(x, row, reduce='max')
77-
out3, _ = segment_csr(x, rowptr, reduce='max')
69+
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
70+
out2 = segment_coo(x, row, reduce='max')
71+
out3 = segment_csr(x, rowptr, reduce='max')
7872

7973
assert torch.allclose(out1, out2, atol=1e-4)
8074
assert torch.allclose(out1, out3, atol=1e-4)
@@ -117,17 +111,15 @@ def timing(dataset):
117111
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
118112
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
119113
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
120-
row_perm = row[torch.randperm(row.size(0))]
114+
row2 = row[torch.randperm(row.size(0))]
121115
dim_size = rowptr.size(0) - 1
122116
avg_row_len = row.size(0) / dim_size
123117

124118
def sca_row(x):
125-
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
126-
return op(x, row, dim=0, dim_size=dim_size)
119+
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
127120

128121
def sca_col(x):
129-
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
130-
return op(x, row_perm, dim=0, dim_size=dim_size)
122+
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
131123

132124
def seg_coo(x):
133125
return segment_coo(x, row, reduce=args.reduce)
@@ -205,11 +197,10 @@ def dense2(x):
205197
if __name__ == '__main__':
206198
parser = argparse.ArgumentParser()
207199
parser.add_argument('--reduce', type=str, required=True,
208-
choices=['sum', 'mean', 'min', 'max'])
200+
choices=['sum', 'add', 'mean', 'min', 'max'])
209201
parser.add_argument('--with_backward', action='store_true')
210202
parser.add_argument('--device', type=str, default='cuda')
211203
args = parser.parse_args()
212-
args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
213204
iters = 1 if args.device == 'cpu' else 20
214205
sizes = [1, 16, 32, 64, 128, 256, 512]
215206
sizes = sizes[:3] if args.device == 'cpu' else sizes

cpu/compat.h

Lines changed: 0 additions & 5 deletions
This file was deleted.

cpu/dim_apply.h

Lines changed: 0 additions & 120 deletions
This file was deleted.

0 commit comments

Comments
 (0)