
Commit 99db5b8

benchmark fixes

1 parent 82838e1

1 file changed: benchmark/scatter_segment.py (+13 additions, -22 deletions)
@@ -7,9 +7,7 @@
 import torch
 from scipy.io import loadmat
 
-import torch_scatter
-from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
-from torch_scatter import segment_coo, segment_csr
+from torch_scatter import scatter, segment_coo, segment_csr
 
 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
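For context on the import change: the specialized helpers (scatter_add, scatter_mean, scatter_min, scatter_max) are subsumed by the generic scatter call, which selects the reduction via its reduce argument. A minimal sketch of the equivalence the rest of this diff relies on (synthetic tensors, illustrative shapes only):

import torch
from torch_scatter import scatter

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 0, 1, 1])

# Groups src entries by index and sums within each group; this is the
# generic spelling of what the removed scatter_add helper computed.
out = scatter(src, index, dim=0, dim_size=2, reduce='add')
print(out)  # tensor([3., 7.])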
@@ -47,34 +45,30 @@ def correctness(dataset):
     x = torch.randn((row.size(0), size), device=args.device)
     x = x.squeeze(-1) if size == 1 else x
 
-    out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
     out3 = segment_csr(x, rowptr, reduce='add')
 
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
 
-    out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
     out3 = segment_csr(x, rowptr, reduce='mean')
 
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
 
-    x = x.abs_().mul_(-1)
-
-    out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='min')
-    out3, _ = segment_csr(x, rowptr, reduce='min')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
+    out2 = segment_coo(x, row, reduce='min')
+    out3 = segment_csr(x, rowptr, reduce='min')
 
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
 
-    x = x.abs_()
-
-    out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='max')
-    out3, _ = segment_csr(x, rowptr, reduce='max')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
+    out2 = segment_coo(x, row, reduce='max')
+    out3 = segment_csr(x, rowptr, reduce='max')
 
     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
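Two workarounds disappear in this hunk. The old scatter_min/scatter_max calls wrote into a caller-supplied torch.zeros_like(out1) buffer and returned (values, indices) tuples, so the benchmark first forced the data all-negative (x.abs_().mul_(-1)) or all-positive (x.abs_()) to keep the zero initialization from winning the comparison. The unified scatter(..., reduce='min'/'max') manages its own output and returns a single tensor, so the checks can run on the raw data. A condensed, self-contained version of the resulting correctness check (random tensors stand in for the SuiteSparse matrices the script actually loads):

import torch
from torch_scatter import scatter, segment_coo, segment_csr

row = torch.tensor([0, 0, 1, 2, 2])   # sorted COO row indices
rowptr = torch.tensor([0, 2, 3, 5])   # CSR pointers over the same rows
x = torch.randn(5, 16)

for reduce in ['add', 'mean', 'min', 'max']:
    out1 = scatter(x, row, dim=0, dim_size=3, reduce=reduce)
    out2 = segment_coo(x, row, dim_size=3, reduce=reduce)
    out3 = segment_csr(x, rowptr, reduce=reduce)
    # All three implementations should agree up to floating-point noise.
    assert torch.allclose(out1, out2, atol=1e-4)
    assert torch.allclose(out1, out3, atol=1e-4)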
@@ -117,17 +111,15 @@ def timing(dataset):
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
     rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
     row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
-    row_perm = row[torch.randperm(row.size(0))]
+    row2 = row[torch.randperm(row.size(0))]
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size
 
     def sca_row(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row, dim=0, dim_size=dim_size)
+        return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
 
     def sca_col(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row_perm, dim=0, dim_size=dim_size)
+        return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
 
     def seg_coo(x):
         return segment_coo(x, row, reduce=args.reduce)
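Here the getattr-based dispatch over scatter_<reduce> collapses into direct scatter calls: sca_row benchmarks scatter over the sorted row vector, sca_col over a random permutation of it (a worst case for memory locality). A minimal standalone timing harness in the same spirit; the sizes and the bench helper are illustrative assumptions, not part of the script, which loads SuiteSparse matrices instead:

import time
import torch
from torch_scatter import scatter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dim_size, nnz = 10_000, 500_000
row, _ = torch.sort(torch.randint(dim_size, (nnz,), device=device))
row2 = row[torch.randperm(nnz, device=device)]  # destroy the ordering
x = torch.randn(nnz, 64, device=device)

def bench(fn, iters=20):
    # Synchronize around the loop so CUDA kernel time is actually measured.
    if device == 'cuda':
        torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    return (time.perf_counter() - t) / iters

sorted_t = bench(lambda x: scatter(x, row, dim=0, dim_size=dim_size, reduce='sum'))
random_t = bench(lambda x: scatter(x, row2, dim=0, dim_size=dim_size, reduce='sum'))
print(f'sorted: {sorted_t:.6f}s, permuted: {random_t:.6f}s')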
@@ -205,11 +197,10 @@ def dense2(x):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--reduce', type=str, required=True,
-                        choices=['sum', 'mean', 'min', 'max'])
+                        choices=['sum', 'add', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
-    args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
     iters = 1 if args.device == 'cpu' else 20
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes
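With the generic API, 'add' can be exposed directly as a CLI choice and the sum-to-add translation line goes away, since scatter accepts both spellings of the summing reduction. A quick sketch of that equivalence (synthetic data for illustration):

import torch
from torch_scatter import scatter

src = torch.ones(4)
index = torch.tensor([0, 0, 1, 1])

# 'sum' and 'add' name the same reduction, so the --reduce value can be
# forwarded to scatter as-is, without the old scatter_reduce remapping.
out_sum = scatter(src, index, dim=0, reduce='sum')
out_add = scatter(src, index, dim=0, reduce='add')
assert torch.equal(out_sum, out_add)  # tensor([2., 2.]) either way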
