@@ -62,139 +62,155 @@ class OpConvolutionBackwardOutTest : public OperatorTest {
6262 grad_weight,
6363 grad_bias);
6464 }
65- };
6665
67- TEST_F (OpConvolutionBackwardOutTest, SmokeTest) {
68- TensorFactory<ScalarType::Float> tf;
66+ template <ScalarType DTYPE>
67+ void test_dtype () {
68+ TensorFactory<DTYPE> tf;
6969
70- std::vector<float > grad_output_data = {
71- 10 , 12 , 87 , 13 , 34 , 87 , 55 , 22 , 48 , 33 , 29 , 38 , 60 , 49 , 88 , 30 ,
72- 99 , 19 , 42 , 37 , 61 , 31 , 33 , 58 , 38 , 23 , 2 , 33 , 3 , 21 , 32 , 2 ,
73- 30 , 72 , 10 , 67 , 92 , 19 , 11 , 16 , 65 , 37 , 60 , 74 , 4 , 19 , 45 , 37 };
74- std::vector<float > input_data = {
75- 9 , 89 , 45 , 39 , 25 , 2 , 97 , 55 , 80 , 24 , 18 , 33 , 28 , 89 , 19 , 16 , 19 , 33 ,
76- 69 , 61 , 34 , 84 , 58 , 30 , 33 , 18 , 75 , 30 , 6 , 33 , 42 , 10 , 80 , 41 , 66 , 64 ,
77- 47 , 51 , 67 , 62 , 58 , 10 , 97 , 71 , 24 , 44 , 84 , 34 , 33 , 54 , 8 , 73 , 90 , 15 ,
78- 21 , 92 , 55 , 22 , 56 , 12 , 10 , 63 , 32 , 76 , 65 , 38 , 95 , 92 , 22 , 15 , 37 , 12 ,
79- 67 , 14 , 60 , 44 , 73 , 74 , 23 , 4 , 56 , 64 , 88 , 90 , 82 , 32 , 91 , 3 , 6 , 87 ,
80- 55 , 95 , 7 , 14 , 24 , 69 , 52 , 44 , 14 , 37 , 75 , 52 , 37 , 40 , 25 , 54 , 4 , 15 ,
81- 97 , 51 , 46 , 28 , 65 , 95 , 50 , 82 , 23 , 39 , 50 , 55 , 97 , 52 , 91 , 16 , 19 , 49 ,
82- 61 , 50 , 42 , 47 , 87 , 99 , 9 , 60 , 22 , 71 , 47 , 17 , 0 , 80 , 28 , 88 , 93 , 43 ,
83- 65 , 25 , 88 , 67 , 21 , 89 , 24 , 81 , 3 , 71 , 20 , 34 , 17 , 17 , 94 , 10 , 82 , 25 ,
84- 10 , 11 , 7 , 28 , 77 , 39 , 74 , 79 , 17 , 40 , 67 , 54 , 49 , 54 , 21 , 89 , 17 , 7 ,
85- 52 , 64 , 68 , 80 , 7 , 72 , 44 , 35 , 92 , 47 , 4 , 13 , 10 , 43 , 64 , 66 , 83 , 49 ,
86- 81 , 78 , 58 , 22 , 86 , 48 , 35 , 64 , 98 , 79 , 8 , 52 , 56 , 23 , 38 , 74 , 16 , 63 ,
87- 51 , 70 , 44 , 28 , 43 , 13 , 51 , 85 , 42 , 29 , 64 , 26 , 54 , 91 , 9 , 96 , 41 , 56 ,
88- 7 , 52 , 27 , 22 , 69 , 13 , 8 , 20 , 22 , 49 , 66 , 98 , 77 , 42 , 54 , 38 , 70 , 83 ,
89- 13 , 8 , 21 , 56 , 78 , 37 , 28 , 69 , 42 , 30 , 91 , 5 , 28 , 15 , 20 , 14 , 16 , 39 ,
90- 95 , 66 , 4 , 72 , 52 , 35 , 54 , 93 , 87 , 77 , 3 , 49 , 82 , 70 , 84 , 3 , 73 , 99 ,
91- 32 , 95 , 58 , 65 , 32 , 75 , 34 , 22 , 12 , 84 , 63 , 72 , 85 , 66 , 63 , 27 , 3 , 73 ,
92- 45 , 37 , 61 , 52 , 41 , 16 , 37 , 14 , 80 , 17 , 48 , 8 , 87 , 98 , 69 , 63 , 92 , 68 ,
93- 42 , 63 , 5 , 22 , 66 , 91 , 74 , 11 , 17 , 45 , 45 , 33 , 40 , 85 , 26 , 75 , 73 , 81 ,
94- 54 , 27 , 80 , 1 , 44 , 66 , 10 , 21 , 15 , 10 , 76 , 96 , 0 , 43 , 39 , 3 , 57 , 79 ,
95- 45 , 64 , 58 , 92 , 44 , 42 , 7 , 28 , 94 , 4 , 8 , 22 , 22 , 31 , 75 , 44 , 3 , 70 ,
96- 83 , 72 , 87 , 12 , 20 , 55 , 84 , 31 , 50 , 34 , 25 , 49 , 29 , 71 , 57 , 97 , 25 , 82 ,
97- 84 , 42 , 86 , 41 , 54 , 92 , 34 , 30 , 52 , 34 , 84 , 25 , 54 , 37 , 38 , 26 , 76 , 82 ,
98- 34 , 14 , 85 , 28 , 93 , 9 };
99- std::vector<float > weight_data = {
100- 2 , 54 , 9 , 37 , 0 , 47 , 70 , 9 , 84 , 69 , 56 , 79 , 25 , 35 , 54 , 13 ,
101- 65 , 46 , 38 , 28 , 74 , 27 , 66 , 61 , 20 , 60 , 62 , 58 , 15 , 44 , 75 , 55 ,
102- 7 , 52 , 13 , 36 , 39 , 64 , 62 , 45 , 100 , 6 , 79 , 63 , 63 , 52 , 37 , 60 ,
103- 78 , 12 , 69 , 2 , 74 , 56 , 93 , 39 , 62 , 22 , 55 , 67 , 68 , 74 , 12 , 69 ,
104- 15 , 73 , 28 , 70 , 86 , 20 , 90 , 49 , 52 , 26 , 58 , 2 , 82 , 17 , 70 , 55 ,
105- 54 , 83 , 70 , 11 , 27 , 9 , 5 , 42 , 34 , 62 , 29 , 94 , 69 , 81 , 54 , 4 };
106- std::vector<float > expected_grad_input_data = {
107- 1134 , 7578 , 686 , 2682 , 0 , 4148 , 7136 , 2406 , 8698 , 0 ,
108- 3759 , 6003 , 2163 , 2395 , 0 , 2929 , 5830 , 3469 , 6955 , 0 ,
109- 720 , 6201 , 495 , 2063 , 0 , 5260 , 5989 , 3060 , 7079 , 0 ,
110- 9690 , 3423 , 3385 , 1932 , 0 , 7644 , 8499 , 1323 , 2613 , 0 ,
111- 4334 , 6624 , 8532 , 9719 , 0 , 5496 , 8601 , 1157 , 2215 , 0 ,
112- 4676 , 7600 , 6524 , 10069 , 0 , 4047 , 6117 , 1612 , 2567 , 0 ,
113- 5931 , 5651 , 5669 , 6623 , 0 , 7674 , 3291 , 2748 , 1654 , 0 ,
114- 10455 , 4290 , 4145 , 796 , 0 , 9835 , 5483 , 11649 , 5952 , 0 ,
115- 7098 , 5460 , 3101 , 2443 , 0 , 7788 , 5909 , 8582 , 6298 , 0 ,
116- 9462 , 4845 , 3041 , 2067 , 0 , 7038 , 6336 , 10438 , 6377 , 0 ,
117- 7518 , 8187 , 2079 , 2773 , 0 , 10036 , 2642 , 3952 , 1166 , 0 ,
118- 16014 , 2250 , 10025 , 1908 , 0 , 9610 , 298 , 3868 , 122 , 0 ,
119- 16629 , 4338 , 11335 , 3527 , 0 , 11514 , 5965 , 4762 , 2207 , 0 ,
120- 18552 , 10755 , 13309 , 5996 , 0 , 12454 , 6787 , 4960 , 2875 , 0 ,
121- 8750 , 6999 , 3534 , 3233 , 0 , 14160 , 9399 , 9595 , 8922 , 0 ,
122- 9110 , 6567 , 3820 , 2351 , 0 , 12969 , 11814 , 9436 , 5870 , 0 ,
123- 7631 , 7061 , 2877 , 2499 , 0 , 8553 , 13527 , 3631 , 6863 , 0 ,
124- 1361 , 8634 , 515 , 3372 , 0 , 3394 , 10206 , 1504 , 4112 , 0 ,
125- 5505 , 17421 , 4702 , 11891 , 0 , 4233 , 11894 , 1739 , 5014 , 0 ,
126- 11787 , 14634 , 8981 , 10759 , 0 , 11777 , 6701 , 4719 , 3111 , 0 ,
127- 18459 , 7761 , 12044 , 7627 , 0 , 11214 , 4556 , 4374 , 1594 , 0 ,
128- 604 , 1908 , 1506 , 6102 , 0 , 2532 , 4024 , 1713 , 6121 , 0 ,
129- 1878 , 1814 , 4761 , 5397 , 0 , 1127 , 3885 , 4373 , 5832 , 0 ,
130- 450 , 1414 , 1080 , 4719 , 0 , 5210 , 2683 , 2765 , 4252 , 0 ,
131- 2390 , 1668 , 7710 , 4257 , 0 , 378 , 1698 , 3276 , 6021 , 0 ,
132- 2866 , 4881 , 3547 , 6822 , 0 , 502 , 1238 , 2784 , 5199 , 0 ,
133- 2496 , 3975 , 2700 , 5004 , 0 , 1220 , 1990 , 3633 , 5763 , 0 ,
134- 4501 , 2679 , 4504 , 5412 , 0 , 1968 , 1376 , 6246 , 3669 , 0 ,
135- 3130 , 272 , 9345 , 1950 , 0 , 5167 , 3278 , 9097 , 2138 , 0 ,
136- 2446 , 1946 , 6942 , 5460 , 0 , 5732 , 3404 , 7919 , 5534 , 0 ,
137- 2038 , 1614 , 6978 , 4635 , 0 , 4544 , 4839 , 7367 , 5574 , 0 ,
138- 1242 , 1922 , 4842 , 6333 , 0 , 1066 , 236 , 2236 , 686 , 0 ,
139- 17238 , 2254 , 10413 , 1592 , 0 , 991 , 30 , 2206 , 70 , 0 ,
140- 18823 , 6392 , 12173 , 2470 , 0 , 1142 , 684 , 2742 , 1219 , 0 ,
141- 21256 , 11293 , 12719 , 7512 , 0 , 1303 , 649 , 2818 , 1669 , 0 ,
142- 898 , 574 , 2018 , 1929 , 0 , 15720 , 11989 , 10517 , 5972 , 0 ,
143- 885 , 781 , 2210 , 1281 , 0 , 14601 , 12198 , 7915 , 4958 , 0 ,
144- 856 , 850 , 1601 , 1355 , 0 , 7039 , 14083 , 4113 , 7490 , 0 ,
145- 152 , 927 , 287 , 1902 , 0 , 301 , 1051 , 886 , 2346 , 0 ,
146- 6821 , 19615 , 4491 , 13281 , 0 , 424 , 1146 , 999 , 2906 , 0 ,
147- 15177 , 15480 , 8849 , 12442 , 0 , 1222 , 544 , 2687 , 1859 , 0 ,
148- 20215 , 9693 , 11441 , 4964 , 0 , 1206 , 555 , 2466 , 860 , 0 };
149- std::vector<float > expected_grad_weight_data = {
150- 9246 , 22073 , 12431 , 19714 , 11179 , 19032 , 8458 , 6495 , 18707 , 13830 ,
151- 20445 , 17089 , 17124 , 18710 , 11827 , 17236 , 16824 , 9008 , 14086 , 18834 ,
152- 17419 , 16759 , 13152 , 9339 , 13801 , 20888 , 13976 , 27277 , 13010 , 23949 ,
153- 9838 , 11220 , 17658 , 15019 , 25337 , 17583 , 13270 , 21754 , 16908 , 20563 ,
154- 20732 , 13413 , 20868 , 27521 , 19537 , 21170 , 15888 , 10034 , 19195 , 16370 ,
155- 40243 , 25890 , 40472 , 30460 , 21228 , 21625 , 13289 , 24435 , 19876 , 29816 ,
156- 24188 , 23619 , 13752 , 16251 , 18741 , 19368 , 24517 , 34261 , 27054 , 31257 ,
157- 21238 , 18909 , 15776 , 16881 , 34604 , 22534 , 28101 , 23834 , 18479 , 16469 ,
158- 12852 , 16551 , 14204 , 29983 , 20167 , 24150 , 14281 , 17501 , 15897 , 16019 ,
159- 21661 , 32765 , 23874 , 26527 , 20463 , 18661 };
160- std::vector<float > expected_grad_bias_data = {363 , 438 , 585 , 501 };
70+ using CTYPE = typename decltype (tf)::ctype;
71+ std::vector<CTYPE> grad_output_data = {
72+ 10 , 12 , 87 , 13 , 34 , 87 , 55 , 22 , 48 , 33 , 29 , 38 , 60 , 49 , 88 , 30 ,
73+ 99 , 19 , 42 , 37 , 61 , 31 , 33 , 58 , 38 , 23 , 2 , 33 , 3 , 21 , 32 , 2 ,
74+ 30 , 72 , 10 , 67 , 92 , 19 , 11 , 16 , 65 , 37 , 60 , 74 , 4 , 19 , 45 , 37 };
75+ std::vector<CTYPE> input_data = {
76+ 9 , 89 , 45 , 39 , 25 , 2 , 97 , 55 , 80 , 24 , 18 , 33 , 28 , 89 , 19 , 16 , 19 , 33 ,
77+ 69 , 61 , 34 , 84 , 58 , 30 , 33 , 18 , 75 , 30 , 6 , 33 , 42 , 10 , 80 , 41 , 66 , 64 ,
78+ 47 , 51 , 67 , 62 , 58 , 10 , 97 , 71 , 24 , 44 , 84 , 34 , 33 , 54 , 8 , 73 , 90 , 15 ,
79+ 21 , 92 , 55 , 22 , 56 , 12 , 10 , 63 , 32 , 76 , 65 , 38 , 95 , 92 , 22 , 15 , 37 , 12 ,
80+ 67 , 14 , 60 , 44 , 73 , 74 , 23 , 4 , 56 , 64 , 88 , 90 , 82 , 32 , 91 , 3 , 6 , 87 ,
81+ 55 , 95 , 7 , 14 , 24 , 69 , 52 , 44 , 14 , 37 , 75 , 52 , 37 , 40 , 25 , 54 , 4 , 15 ,
82+ 97 , 51 , 46 , 28 , 65 , 95 , 50 , 82 , 23 , 39 , 50 , 55 , 97 , 52 , 91 , 16 , 19 , 49 ,
83+ 61 , 50 , 42 , 47 , 87 , 99 , 9 , 60 , 22 , 71 , 47 , 17 , 0 , 80 , 28 , 88 , 93 , 43 ,
84+ 65 , 25 , 88 , 67 , 21 , 89 , 24 , 81 , 3 , 71 , 20 , 34 , 17 , 17 , 94 , 10 , 82 , 25 ,
85+ 10 , 11 , 7 , 28 , 77 , 39 , 74 , 79 , 17 , 40 , 67 , 54 , 49 , 54 , 21 , 89 , 17 , 7 ,
86+ 52 , 64 , 68 , 80 , 7 , 72 , 44 , 35 , 92 , 47 , 4 , 13 , 10 , 43 , 64 , 66 , 83 , 49 ,
87+ 81 , 78 , 58 , 22 , 86 , 48 , 35 , 64 , 98 , 79 , 8 , 52 , 56 , 23 , 38 , 74 , 16 , 63 ,
88+ 51 , 70 , 44 , 28 , 43 , 13 , 51 , 85 , 42 , 29 , 64 , 26 , 54 , 91 , 9 , 96 , 41 , 56 ,
89+ 7 , 52 , 27 , 22 , 69 , 13 , 8 , 20 , 22 , 49 , 66 , 98 , 77 , 42 , 54 , 38 , 70 , 83 ,
90+ 13 , 8 , 21 , 56 , 78 , 37 , 28 , 69 , 42 , 30 , 91 , 5 , 28 , 15 , 20 , 14 , 16 , 39 ,
91+ 95 , 66 , 4 , 72 , 52 , 35 , 54 , 93 , 87 , 77 , 3 , 49 , 82 , 70 , 84 , 3 , 73 , 99 ,
92+ 32 , 95 , 58 , 65 , 32 , 75 , 34 , 22 , 12 , 84 , 63 , 72 , 85 , 66 , 63 , 27 , 3 , 73 ,
93+ 45 , 37 , 61 , 52 , 41 , 16 , 37 , 14 , 80 , 17 , 48 , 8 , 87 , 98 , 69 , 63 , 92 , 68 ,
94+ 42 , 63 , 5 , 22 , 66 , 91 , 74 , 11 , 17 , 45 , 45 , 33 , 40 , 85 , 26 , 75 , 73 , 81 ,
95+ 54 , 27 , 80 , 1 , 44 , 66 , 10 , 21 , 15 , 10 , 76 , 96 , 0 , 43 , 39 , 3 , 57 , 79 ,
96+ 45 , 64 , 58 , 92 , 44 , 42 , 7 , 28 , 94 , 4 , 8 , 22 , 22 , 31 , 75 , 44 , 3 , 70 ,
97+ 83 , 72 , 87 , 12 , 20 , 55 , 84 , 31 , 50 , 34 , 25 , 49 , 29 , 71 , 57 , 97 , 25 , 82 ,
98+ 84 , 42 , 86 , 41 , 54 , 92 , 34 , 30 , 52 , 34 , 84 , 25 , 54 , 37 , 38 , 26 , 76 , 82 ,
99+ 34 , 14 , 85 , 28 , 93 , 9 };
100+ std::vector<CTYPE> weight_data = {
101+ 2 , 54 , 9 , 37 , 0 , 47 , 70 , 9 , 84 , 69 , 56 , 79 , 25 , 35 , 54 , 13 ,
102+ 65 , 46 , 38 , 28 , 74 , 27 , 66 , 61 , 20 , 60 , 62 , 58 , 15 , 44 , 75 , 55 ,
103+ 7 , 52 , 13 , 36 , 39 , 64 , 62 , 45 , 100 , 6 , 79 , 63 , 63 , 52 , 37 , 60 ,
104+ 78 , 12 , 69 , 2 , 74 , 56 , 93 , 39 , 62 , 22 , 55 , 67 , 68 , 74 , 12 , 69 ,
105+ 15 , 73 , 28 , 70 , 86 , 20 , 90 , 49 , 52 , 26 , 58 , 2 , 82 , 17 , 70 , 55 ,
106+ 54 , 83 , 70 , 11 , 27 , 9 , 5 , 42 , 34 , 62 , 29 , 94 , 69 , 81 , 54 , 4 };
107+ std::vector<CTYPE> expected_grad_input_data = {
108+ 1134 , 7578 , 686 , 2682 , 0 , 4148 , 7136 , 2406 , 8698 , 0 ,
109+ 3759 , 6003 , 2163 , 2395 , 0 , 2929 , 5830 , 3469 , 6955 , 0 ,
110+ 720 , 6201 , 495 , 2063 , 0 , 5260 , 5989 , 3060 , 7079 , 0 ,
111+ 9690 , 3423 , 3385 , 1932 , 0 , 7644 , 8499 , 1323 , 2613 , 0 ,
112+ 4334 , 6624 , 8532 , 9719 , 0 , 5496 , 8601 , 1157 , 2215 , 0 ,
113+ 4676 , 7600 , 6524 , 10069 , 0 , 4047 , 6117 , 1612 , 2567 , 0 ,
114+ 5931 , 5651 , 5669 , 6623 , 0 , 7674 , 3291 , 2748 , 1654 , 0 ,
115+ 10455 , 4290 , 4145 , 796 , 0 , 9835 , 5483 , 11649 , 5952 , 0 ,
116+ 7098 , 5460 , 3101 , 2443 , 0 , 7788 , 5909 , 8582 , 6298 , 0 ,
117+ 9462 , 4845 , 3041 , 2067 , 0 , 7038 , 6336 , 10438 , 6377 , 0 ,
118+ 7518 , 8187 , 2079 , 2773 , 0 , 10036 , 2642 , 3952 , 1166 , 0 ,
119+ 16014 , 2250 , 10025 , 1908 , 0 , 9610 , 298 , 3868 , 122 , 0 ,
120+ 16629 , 4338 , 11335 , 3527 , 0 , 11514 , 5965 , 4762 , 2207 , 0 ,
121+ 18552 , 10755 , 13309 , 5996 , 0 , 12454 , 6787 , 4960 , 2875 , 0 ,
122+ 8750 , 6999 , 3534 , 3233 , 0 , 14160 , 9399 , 9595 , 8922 , 0 ,
123+ 9110 , 6567 , 3820 , 2351 , 0 , 12969 , 11814 , 9436 , 5870 , 0 ,
124+ 7631 , 7061 , 2877 , 2499 , 0 , 8553 , 13527 , 3631 , 6863 , 0 ,
125+ 1361 , 8634 , 515 , 3372 , 0 , 3394 , 10206 , 1504 , 4112 , 0 ,
126+ 5505 , 17421 , 4702 , 11891 , 0 , 4233 , 11894 , 1739 , 5014 , 0 ,
127+ 11787 , 14634 , 8981 , 10759 , 0 , 11777 , 6701 , 4719 , 3111 , 0 ,
128+ 18459 , 7761 , 12044 , 7627 , 0 , 11214 , 4556 , 4374 , 1594 , 0 ,
129+ 604 , 1908 , 1506 , 6102 , 0 , 2532 , 4024 , 1713 , 6121 , 0 ,
130+ 1878 , 1814 , 4761 , 5397 , 0 , 1127 , 3885 , 4373 , 5832 , 0 ,
131+ 450 , 1414 , 1080 , 4719 , 0 , 5210 , 2683 , 2765 , 4252 , 0 ,
132+ 2390 , 1668 , 7710 , 4257 , 0 , 378 , 1698 , 3276 , 6021 , 0 ,
133+ 2866 , 4881 , 3547 , 6822 , 0 , 502 , 1238 , 2784 , 5199 , 0 ,
134+ 2496 , 3975 , 2700 , 5004 , 0 , 1220 , 1990 , 3633 , 5763 , 0 ,
135+ 4501 , 2679 , 4504 , 5412 , 0 , 1968 , 1376 , 6246 , 3669 , 0 ,
136+ 3130 , 272 , 9345 , 1950 , 0 , 5167 , 3278 , 9097 , 2138 , 0 ,
137+ 2446 , 1946 , 6942 , 5460 , 0 , 5732 , 3404 , 7919 , 5534 , 0 ,
138+ 2038 , 1614 , 6978 , 4635 , 0 , 4544 , 4839 , 7367 , 5574 , 0 ,
139+ 1242 , 1922 , 4842 , 6333 , 0 , 1066 , 236 , 2236 , 686 , 0 ,
140+ 17238 , 2254 , 10413 , 1592 , 0 , 991 , 30 , 2206 , 70 , 0 ,
141+ 18823 , 6392 , 12173 , 2470 , 0 , 1142 , 684 , 2742 , 1219 , 0 ,
142+ 21256 , 11293 , 12719 , 7512 , 0 , 1303 , 649 , 2818 , 1669 , 0 ,
143+ 898 , 574 , 2018 , 1929 , 0 , 15720 , 11989 , 10517 , 5972 , 0 ,
144+ 885 , 781 , 2210 , 1281 , 0 , 14601 , 12198 , 7915 , 4958 , 0 ,
145+ 856 , 850 , 1601 , 1355 , 0 , 7039 , 14083 , 4113 , 7490 , 0 ,
146+ 152 , 927 , 287 , 1902 , 0 , 301 , 1051 , 886 , 2346 , 0 ,
147+ 6821 , 19615 , 4491 , 13281 , 0 , 424 , 1146 , 999 , 2906 , 0 ,
148+ 15177 , 15480 , 8849 , 12442 , 0 , 1222 , 544 , 2687 , 1859 , 0 ,
149+ 20215 , 9693 , 11441 , 4964 , 0 , 1206 , 555 , 2466 , 860 , 0 };
150+ std::vector<CTYPE> expected_grad_weight_data = {
151+ 9246 , 22073 , 12431 , 19714 , 11179 , 19032 , 8458 , 6495 , 18707 , 13830 ,
152+ 20445 , 17089 , 17124 , 18710 , 11827 , 17236 , 16824 , 9008 , 14086 , 18834 ,
153+ 17419 , 16759 , 13152 , 9339 , 13801 , 20888 , 13976 , 27277 , 13010 , 23949 ,
154+ 9838 , 11220 , 17658 , 15019 , 25337 , 17583 , 13270 , 21754 , 16908 , 20563 ,
155+ 20732 , 13413 , 20868 , 27521 , 19537 , 21170 , 15888 , 10034 , 19195 , 16370 ,
156+ 40243 , 25890 , 40472 , 30460 , 21228 , 21625 , 13289 , 24435 , 19876 , 29816 ,
157+ 24188 , 23619 , 13752 , 16251 , 18741 , 19368 , 24517 , 34261 , 27054 , 31257 ,
158+ 21238 , 18909 , 15776 , 16881 , 34604 , 22534 , 28101 , 23834 , 18479 , 16469 ,
159+ 12852 , 16551 , 14204 , 29983 , 20167 , 24150 , 14281 , 17501 , 15897 , 16019 ,
160+ 21661 , 32765 , 23874 , 26527 , 20463 , 18661 };
161+ std::vector<CTYPE> expected_grad_bias_data = {363 , 438 , 585 , 501 };
161162
162- auto grad_output = tf.make ({2 , 4 , 3 , 2 }, grad_output_data);
163- auto input = tf.make ({2 , 6 , 7 , 5 }, input_data);
164- auto weight = tf.make ({4 , 3 , 4 , 2 }, weight_data);
165- int64_t bias_sizes[1 ] = {4 };
166- int64_t stride[2 ] = {1 , 2 };
167- int64_t padding[2 ] = {1 , 0 };
168- int64_t dilation[2 ] = {2 , 1 };
169- bool transposed = false ;
170- int64_t output_padding[2 ] = {0 , 0 };
171- int64_t groups = 2 ;
172- std::array<bool , 3 > output_mask_a = {true , true , true };
173- auto grad_input = tf.zeros ({2 , 6 , 7 , 5 });
174- auto grad_weight = tf.zeros ({4 , 3 , 4 , 2 });
175- auto grad_bias = tf.zeros ({4 });
163+ auto grad_output = tf.make ({2 , 4 , 3 , 2 }, grad_output_data);
164+ auto input = tf.make ({2 , 6 , 7 , 5 }, input_data);
165+ auto weight = tf.make ({4 , 3 , 4 , 2 }, weight_data);
166+ int64_t bias_sizes[1 ] = {4 };
167+ int64_t stride[2 ] = {1 , 2 };
168+ int64_t padding[2 ] = {1 , 0 };
169+ int64_t dilation[2 ] = {2 , 1 };
170+ bool transposed = false ;
171+ int64_t output_padding[2 ] = {0 , 0 };
172+ int64_t groups = 2 ;
173+ std::array<bool , 3 > output_mask_a = {true , true , true };
174+ auto grad_input = tf.zeros ({2 , 6 , 7 , 5 });
175+ auto grad_weight = tf.zeros ({4 , 3 , 4 , 2 });
176+ auto grad_bias = tf.zeros ({4 });
177+
178+ op_convolution_backward_out (
179+ grad_output,
180+ input,
181+ weight,
182+ IntArrayRef{bias_sizes, 1 },
183+ IntArrayRef{stride, 2 },
184+ IntArrayRef{padding, 2 },
185+ IntArrayRef{dilation, 2 },
186+ transposed,
187+ IntArrayRef{output_padding, 2 },
188+ groups,
189+ output_mask_a,
190+ grad_input,
191+ grad_weight,
192+ grad_bias);
176193
177- op_convolution_backward_out (
178- grad_output,
179- input,
180- weight,
181- IntArrayRef{bias_sizes, 1 },
182- IntArrayRef{stride, 2 },
183- IntArrayRef{padding, 2 },
184- IntArrayRef{dilation, 2 },
185- transposed,
186- IntArrayRef{output_padding, 2 },
187- groups,
188- output_mask_a,
189- grad_input,
190- grad_weight,
191- grad_bias);
194+ auto expected_grad_input = tf.make ({2 , 6 , 7 , 5 }, expected_grad_input_data);
195+ auto expected_grad_weight =
196+ tf.make ({4 , 3 , 4 , 2 }, expected_grad_weight_data);
197+ auto expected_grad_bias = tf.make ({4 }, expected_grad_bias_data);
192198
193- auto expected_grad_input = tf.make ({2 , 6 , 7 , 5 }, expected_grad_input_data);
194- auto expected_grad_weight = tf.make ({4 , 3 , 4 , 2 }, expected_grad_weight_data);
195- auto expected_grad_bias = tf.make ({4 }, expected_grad_bias_data);
199+ if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
200+ EXPECT_TENSOR_CLOSE_WITH_TOL (grad_input, expected_grad_input, 1e-2 , 1e-8 );
201+ EXPECT_TENSOR_CLOSE_WITH_TOL (
202+ grad_weight, expected_grad_weight, 2e-2 , 1e-8 );
203+ EXPECT_TENSOR_CLOSE_WITH_TOL (grad_bias, expected_grad_bias, 1e-2 , 1e-8 );
204+ } else {
205+ EXPECT_TENSOR_CLOSE (grad_input, expected_grad_input);
206+ EXPECT_TENSOR_CLOSE (grad_weight, expected_grad_weight);
207+ EXPECT_TENSOR_CLOSE (grad_bias, expected_grad_bias);
208+ }
209+ }
210+ };
196211
197- EXPECT_TENSOR_CLOSE (grad_input, expected_grad_input);
198- EXPECT_TENSOR_CLOSE (grad_weight, expected_grad_weight);
199- EXPECT_TENSOR_CLOSE (grad_bias, expected_grad_bias);
212+ TEST_F (OpConvolutionBackwardOutTest, SmokeTest) {
213+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
214+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
215+ #undef TEST_ENTRY
200216}
0 commit comments