|
7 | 7 |
|
8 | 8 | from .utils import tensor, dtypes, devices |
9 | 9 |
|
10 | | -reductions = ['add', 'mean', 'min', 'max'] |
11 | | -grad_reductions = ['add', 'mean'] |
| 10 | +reductions = ['sum', 'mean', 'min', 'max'] |
| 11 | +grad_reductions = ['sum', 'mean'] |
12 | 12 |
|
13 | 13 | tests = [ |
14 | 14 | { |
15 | 15 | 'src': [1, 2, 3, 4, 5, 6], |
16 | 16 | 'index': [0, 0, 1, 1, 1, 3], |
17 | 17 | 'indptr': [0, 2, 5, 5, 6], |
18 | | - 'add': [3, 12, 0, 6], |
| 18 | + 'sum': [3, 12, 0, 6], |
19 | 19 | 'mean': [1.5, 4, 0, 6], |
20 | 20 | 'min': [1, 3, 0, 6], |
21 | 21 | 'arg_min': [0, 2, 6, 5], |
|
26 | 26 | 'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], |
27 | 27 | 'index': [0, 0, 1, 1, 1, 3], |
28 | 28 | 'indptr': [0, 2, 5, 5, 6], |
29 | | - 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], |
| 29 | + 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], |
30 | 30 | 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], |
31 | 31 | 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], |
32 | 32 | 'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]], |
|
37 | 37 | 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], |
38 | 38 | 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], |
39 | 39 | 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], |
40 | | - 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], |
| 40 | + 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], |
41 | 41 | 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], |
42 | 42 | 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], |
43 | 43 | 'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]], |
|
48 | 48 | 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], |
49 | 49 | 'index': [[0, 0, 1], [0, 2, 2]], |
50 | 50 | 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], |
51 | | - 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], |
| 51 | + 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], |
52 | 52 | 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], |
53 | 53 | 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], |
54 | 54 | 'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]], |
|
59 | 59 | 'src': [[1, 3], [2, 4]], |
60 | 60 | 'index': [[0, 0], [0, 0]], |
61 | 61 | 'indptr': [[0, 2], [0, 2]], |
62 | | - 'add': [[4], [6]], |
| 62 | + 'sum': [[4], [6]], |
63 | 63 | 'mean': [[2], [3]], |
64 | 64 | 'min': [[1], [2]], |
65 | 65 | 'arg_min': [[0], [0]], |
|
70 | 70 | 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], |
71 | 71 | 'index': [[0, 0], [0, 0]], |
72 | 72 | 'indptr': [[0, 2], [0, 2]], |
73 | | - 'add': [[[4, 4]], [[6, 6]]], |
| 73 | + 'sum': [[[4, 4]], [[6, 6]]], |
74 | 74 | 'mean': [[[2, 2]], [[3, 3]]], |
75 | 75 | 'min': [[[1, 1]], [[2, 2]]], |
76 | 76 | 'arg_min': [[[0, 0]], [[0, 0]]], |
@@ -134,7 +134,7 @@ def test_segment_out(test, reduce, dtype, device): |
134 | 134 |
|
135 | 135 | segment_coo(src, index, out, reduce=reduce) |
136 | 136 |
|
137 | | - if reduce == 'add': |
| 137 | + if reduce == 'sum': |
138 | 138 | expected = expected - 2 |
139 | 139 | elif reduce == 'mean': |
140 | 140 | expected = out # We can not really test this here. |
|
0 commit comments