|
7 | 7 |
|
8 | 8 | from .utils import reductions, tensor, dtypes, devices |
9 | 9 |
|
| 10 | +reductions = reductions + ['mul'] |
| 11 | + |
10 | 12 | tests = [ |
11 | 13 | { |
12 | 14 | 'src': [1, 3, 2, 4, 5, 6], |
13 | 15 | 'index': [0, 1, 0, 1, 1, 3], |
14 | 16 | 'dim': 0, |
15 | 17 | 'sum': [3, 12, 0, 6], |
16 | 18 | 'add': [3, 12, 0, 6], |
| 19 | + 'mul': [2, 60, 1, 6], |
17 | 20 | 'mean': [1.5, 4, 0, 6], |
18 | 21 | 'min': [1, 3, 0, 6], |
19 | 22 | 'arg_min': [0, 1, 6, 5], |
|
26 | 29 | 'dim': 0, |
27 | 30 | 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], |
28 | 31 | 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], |
| 32 | + 'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]], |
29 | 33 | 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], |
30 | 34 | 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], |
31 | 35 | 'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]], |
|
38 | 42 | 'dim': 1, |
39 | 43 | 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], |
40 | 44 | 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], |
| 45 | + 'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]], |
41 | 46 | 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], |
42 | 47 | 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], |
43 | 48 | 'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]], |
|
50 | 55 | 'dim': 1, |
51 | 56 | 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], |
52 | 57 | 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], |
| 58 | + 'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]], |
53 | 59 | 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], |
54 | 60 | 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], |
55 | 61 | 'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], |
|
62 | 68 | 'dim': 1, |
63 | 69 | 'sum': [[4], [6]], |
64 | 70 | 'add': [[4], [6]], |
| 71 | + 'mul': [[3], [8]], |
65 | 72 | 'mean': [[2], [3]], |
66 | 73 | 'min': [[1], [2]], |
67 | 74 | 'arg_min': [[0], [0]], |
|
74 | 81 | 'dim': 1, |
75 | 82 | 'sum': [[[4, 4]], [[6, 6]]], |
76 | 83 | 'add': [[[4, 4]], [[6, 6]]], |
| 84 | + 'mul': [[[3, 3]], [[8, 8]]], |
77 | 85 | 'mean': [[[2, 2]], [[3, 3]]], |
78 | 86 | 'min': [[[1, 1]], [[2, 2]]], |
79 | 87 | 'arg_min': [[[0, 0]], [[0, 0]]], |
@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device): |
125 | 133 |
|
126 | 134 | if reduce == 'sum' or reduce == 'add': |
127 | 135 | expected = expected - 2 |
| 136 | + elif reduce == 'mul': |
| 137 | + expected = out # We can not really test this here. |
128 | 138 | elif reduce == 'mean': |
129 | 139 | expected = out # We can not really test this here. |
130 | 140 | elif reduce == 'min': |
|
0 commit comments