|
7 | 7 | import torch |
8 | 8 | from scipy.io import loadmat |
9 | 9 |
|
10 | | -import torch_scatter |
11 | | -from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max |
12 | | -from torch_scatter import segment_coo, segment_csr |
| 10 | +from torch_scatter import scatter, segment_coo, segment_csr |
13 | 11 |
|
14 | 12 | short_rows = [ |
15 | 13 | ('DIMACS10', 'citationCiteseer'), |
@@ -47,34 +45,30 @@ def correctness(dataset): |
47 | 45 | x = torch.randn((row.size(0), size), device=args.device) |
48 | 46 | x = x.squeeze(-1) if size == 1 else x |
49 | 47 |
|
50 | | - out1 = scatter_add(x, row, dim=0, dim_size=dim_size) |
| 48 | + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add') |
51 | 49 | out2 = segment_coo(x, row, dim_size=dim_size, reduce='add') |
52 | 50 | out3 = segment_csr(x, rowptr, reduce='add') |
53 | 51 |
|
54 | 52 | assert torch.allclose(out1, out2, atol=1e-4) |
55 | 53 | assert torch.allclose(out1, out3, atol=1e-4) |
56 | 54 |
|
57 | | - out1 = scatter_mean(x, row, dim=0, dim_size=dim_size) |
| 55 | + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean') |
58 | 56 | out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean') |
59 | 57 | out3 = segment_csr(x, rowptr, reduce='mean') |
60 | 58 |
|
61 | 59 | assert torch.allclose(out1, out2, atol=1e-4) |
62 | 60 | assert torch.allclose(out1, out3, atol=1e-4) |
63 | 61 |
|
64 | | - x = x.abs_().mul_(-1) |
65 | | - |
66 | | - out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1)) |
67 | | - out2, _ = segment_coo(x, row, reduce='min') |
68 | | - out3, _ = segment_csr(x, rowptr, reduce='min') |
| 62 | + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min') |
| 63 | + out2 = segment_coo(x, row, reduce='min') |
| 64 | + out3 = segment_csr(x, rowptr, reduce='min') |
69 | 65 |
|
70 | 66 | assert torch.allclose(out1, out2, atol=1e-4) |
71 | 67 | assert torch.allclose(out1, out3, atol=1e-4) |
72 | 68 |
|
73 | | - x = x.abs_() |
74 | | - |
75 | | - out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1)) |
76 | | - out2, _ = segment_coo(x, row, reduce='max') |
77 | | - out3, _ = segment_csr(x, rowptr, reduce='max') |
| 69 | + out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max') |
| 70 | + out2 = segment_coo(x, row, reduce='max') |
| 71 | + out3 = segment_csr(x, rowptr, reduce='max') |
78 | 72 |
|
79 | 73 | assert torch.allclose(out1, out2, atol=1e-4) |
80 | 74 | assert torch.allclose(out1, out3, atol=1e-4) |
@@ -117,17 +111,15 @@ def timing(dataset): |
117 | 111 | mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() |
118 | 112 | rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) |
119 | 113 | row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) |
120 | | - row_perm = row[torch.randperm(row.size(0))] |
| 114 | + row2 = row[torch.randperm(row.size(0))] |
121 | 115 | dim_size = rowptr.size(0) - 1 |
122 | 116 | avg_row_len = row.size(0) / dim_size |
123 | 117 |
|
124 | 118 | def sca_row(x): |
125 | | - op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}') |
126 | | - return op(x, row, dim=0, dim_size=dim_size) |
| 119 | + return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce) |
127 | 120 |
|
128 | 121 | def sca_col(x): |
129 | | - op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}') |
130 | | - return op(x, row_perm, dim=0, dim_size=dim_size) |
| 122 | + return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce) |
131 | 123 |
|
132 | 124 | def seg_coo(x): |
133 | 125 | return segment_coo(x, row, reduce=args.reduce) |
@@ -205,11 +197,10 @@ def dense2(x): |
205 | 197 | if __name__ == '__main__': |
206 | 198 | parser = argparse.ArgumentParser() |
207 | 199 | parser.add_argument('--reduce', type=str, required=True, |
208 | | - choices=['sum', 'mean', 'min', 'max']) |
| 200 | + choices=['sum', 'add', 'mean', 'min', 'max']) |
209 | 201 | parser.add_argument('--with_backward', action='store_true') |
210 | 202 | parser.add_argument('--device', type=str, default='cuda') |
211 | 203 | args = parser.parse_args() |
212 | | - args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce |
213 | 204 | iters = 1 if args.device == 'cpu' else 20 |
214 | 205 | sizes = [1, 16, 32, 64, 128, 256, 512] |
215 | 206 | sizes = sizes[:3] if args.device == 'cpu' else sizes |
|
0 commit comments