Skip to content

Commit 8e7b91e

Browse files
swolchokZonglin Peng
authored andcommitted
Support BFloat16 in convolution_backward (pytorch#7807)
Partial fix for pytorch#7748.
1 parent 80c9acd commit 8e7b91e

File tree

3 files changed

+147
-130
lines changed

3 files changed

+147
-130
lines changed

kernels/portable/cpu/op_convolution_backward.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
289289

290290
constexpr auto name = "convolution_backward.out";
291291

292-
ET_SWITCH_FLOATH_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
292+
ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
293293
conv2d_backward_impl<CTYPE>(
294294
grad_output,
295295
input,

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ set(all_test_sources
117117
"op_clamp_test.cpp"
118118
"op_clone_test.cpp"
119119
"op_constant_pad_nd_test.cpp"
120+
"op_convolution_backward_test.cpp"
120121
"op_convolution_test.cpp"
121122
"op_copy_test.cpp"
122123
"op_cos_test.cpp"

kernels/test/op_convolution_backward_test.cpp

Lines changed: 145 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -62,139 +62,155 @@ class OpConvolutionBackwardOutTest : public OperatorTest {
6262
grad_weight,
6363
grad_bias);
6464
}
65-
};
6665

67-
TEST_F(OpConvolutionBackwardOutTest, SmokeTest) {
68-
TensorFactory<ScalarType::Float> tf;
66+
template <ScalarType DTYPE>
67+
void test_dtype() {
68+
TensorFactory<DTYPE> tf;
6969

70-
std::vector<float> grad_output_data = {
71-
10, 12, 87, 13, 34, 87, 55, 22, 48, 33, 29, 38, 60, 49, 88, 30,
72-
99, 19, 42, 37, 61, 31, 33, 58, 38, 23, 2, 33, 3, 21, 32, 2,
73-
30, 72, 10, 67, 92, 19, 11, 16, 65, 37, 60, 74, 4, 19, 45, 37};
74-
std::vector<float> input_data = {
75-
9, 89, 45, 39, 25, 2, 97, 55, 80, 24, 18, 33, 28, 89, 19, 16, 19, 33,
76-
69, 61, 34, 84, 58, 30, 33, 18, 75, 30, 6, 33, 42, 10, 80, 41, 66, 64,
77-
47, 51, 67, 62, 58, 10, 97, 71, 24, 44, 84, 34, 33, 54, 8, 73, 90, 15,
78-
21, 92, 55, 22, 56, 12, 10, 63, 32, 76, 65, 38, 95, 92, 22, 15, 37, 12,
79-
67, 14, 60, 44, 73, 74, 23, 4, 56, 64, 88, 90, 82, 32, 91, 3, 6, 87,
80-
55, 95, 7, 14, 24, 69, 52, 44, 14, 37, 75, 52, 37, 40, 25, 54, 4, 15,
81-
97, 51, 46, 28, 65, 95, 50, 82, 23, 39, 50, 55, 97, 52, 91, 16, 19, 49,
82-
61, 50, 42, 47, 87, 99, 9, 60, 22, 71, 47, 17, 0, 80, 28, 88, 93, 43,
83-
65, 25, 88, 67, 21, 89, 24, 81, 3, 71, 20, 34, 17, 17, 94, 10, 82, 25,
84-
10, 11, 7, 28, 77, 39, 74, 79, 17, 40, 67, 54, 49, 54, 21, 89, 17, 7,
85-
52, 64, 68, 80, 7, 72, 44, 35, 92, 47, 4, 13, 10, 43, 64, 66, 83, 49,
86-
81, 78, 58, 22, 86, 48, 35, 64, 98, 79, 8, 52, 56, 23, 38, 74, 16, 63,
87-
51, 70, 44, 28, 43, 13, 51, 85, 42, 29, 64, 26, 54, 91, 9, 96, 41, 56,
88-
7, 52, 27, 22, 69, 13, 8, 20, 22, 49, 66, 98, 77, 42, 54, 38, 70, 83,
89-
13, 8, 21, 56, 78, 37, 28, 69, 42, 30, 91, 5, 28, 15, 20, 14, 16, 39,
90-
95, 66, 4, 72, 52, 35, 54, 93, 87, 77, 3, 49, 82, 70, 84, 3, 73, 99,
91-
32, 95, 58, 65, 32, 75, 34, 22, 12, 84, 63, 72, 85, 66, 63, 27, 3, 73,
92-
45, 37, 61, 52, 41, 16, 37, 14, 80, 17, 48, 8, 87, 98, 69, 63, 92, 68,
93-
42, 63, 5, 22, 66, 91, 74, 11, 17, 45, 45, 33, 40, 85, 26, 75, 73, 81,
94-
54, 27, 80, 1, 44, 66, 10, 21, 15, 10, 76, 96, 0, 43, 39, 3, 57, 79,
95-
45, 64, 58, 92, 44, 42, 7, 28, 94, 4, 8, 22, 22, 31, 75, 44, 3, 70,
96-
83, 72, 87, 12, 20, 55, 84, 31, 50, 34, 25, 49, 29, 71, 57, 97, 25, 82,
97-
84, 42, 86, 41, 54, 92, 34, 30, 52, 34, 84, 25, 54, 37, 38, 26, 76, 82,
98-
34, 14, 85, 28, 93, 9};
99-
std::vector<float> weight_data = {
100-
2, 54, 9, 37, 0, 47, 70, 9, 84, 69, 56, 79, 25, 35, 54, 13,
101-
65, 46, 38, 28, 74, 27, 66, 61, 20, 60, 62, 58, 15, 44, 75, 55,
102-
7, 52, 13, 36, 39, 64, 62, 45, 100, 6, 79, 63, 63, 52, 37, 60,
103-
78, 12, 69, 2, 74, 56, 93, 39, 62, 22, 55, 67, 68, 74, 12, 69,
104-
15, 73, 28, 70, 86, 20, 90, 49, 52, 26, 58, 2, 82, 17, 70, 55,
105-
54, 83, 70, 11, 27, 9, 5, 42, 34, 62, 29, 94, 69, 81, 54, 4};
106-
std::vector<float> expected_grad_input_data = {
107-
1134, 7578, 686, 2682, 0, 4148, 7136, 2406, 8698, 0,
108-
3759, 6003, 2163, 2395, 0, 2929, 5830, 3469, 6955, 0,
109-
720, 6201, 495, 2063, 0, 5260, 5989, 3060, 7079, 0,
110-
9690, 3423, 3385, 1932, 0, 7644, 8499, 1323, 2613, 0,
111-
4334, 6624, 8532, 9719, 0, 5496, 8601, 1157, 2215, 0,
112-
4676, 7600, 6524, 10069, 0, 4047, 6117, 1612, 2567, 0,
113-
5931, 5651, 5669, 6623, 0, 7674, 3291, 2748, 1654, 0,
114-
10455, 4290, 4145, 796, 0, 9835, 5483, 11649, 5952, 0,
115-
7098, 5460, 3101, 2443, 0, 7788, 5909, 8582, 6298, 0,
116-
9462, 4845, 3041, 2067, 0, 7038, 6336, 10438, 6377, 0,
117-
7518, 8187, 2079, 2773, 0, 10036, 2642, 3952, 1166, 0,
118-
16014, 2250, 10025, 1908, 0, 9610, 298, 3868, 122, 0,
119-
16629, 4338, 11335, 3527, 0, 11514, 5965, 4762, 2207, 0,
120-
18552, 10755, 13309, 5996, 0, 12454, 6787, 4960, 2875, 0,
121-
8750, 6999, 3534, 3233, 0, 14160, 9399, 9595, 8922, 0,
122-
9110, 6567, 3820, 2351, 0, 12969, 11814, 9436, 5870, 0,
123-
7631, 7061, 2877, 2499, 0, 8553, 13527, 3631, 6863, 0,
124-
1361, 8634, 515, 3372, 0, 3394, 10206, 1504, 4112, 0,
125-
5505, 17421, 4702, 11891, 0, 4233, 11894, 1739, 5014, 0,
126-
11787, 14634, 8981, 10759, 0, 11777, 6701, 4719, 3111, 0,
127-
18459, 7761, 12044, 7627, 0, 11214, 4556, 4374, 1594, 0,
128-
604, 1908, 1506, 6102, 0, 2532, 4024, 1713, 6121, 0,
129-
1878, 1814, 4761, 5397, 0, 1127, 3885, 4373, 5832, 0,
130-
450, 1414, 1080, 4719, 0, 5210, 2683, 2765, 4252, 0,
131-
2390, 1668, 7710, 4257, 0, 378, 1698, 3276, 6021, 0,
132-
2866, 4881, 3547, 6822, 0, 502, 1238, 2784, 5199, 0,
133-
2496, 3975, 2700, 5004, 0, 1220, 1990, 3633, 5763, 0,
134-
4501, 2679, 4504, 5412, 0, 1968, 1376, 6246, 3669, 0,
135-
3130, 272, 9345, 1950, 0, 5167, 3278, 9097, 2138, 0,
136-
2446, 1946, 6942, 5460, 0, 5732, 3404, 7919, 5534, 0,
137-
2038, 1614, 6978, 4635, 0, 4544, 4839, 7367, 5574, 0,
138-
1242, 1922, 4842, 6333, 0, 1066, 236, 2236, 686, 0,
139-
17238, 2254, 10413, 1592, 0, 991, 30, 2206, 70, 0,
140-
18823, 6392, 12173, 2470, 0, 1142, 684, 2742, 1219, 0,
141-
21256, 11293, 12719, 7512, 0, 1303, 649, 2818, 1669, 0,
142-
898, 574, 2018, 1929, 0, 15720, 11989, 10517, 5972, 0,
143-
885, 781, 2210, 1281, 0, 14601, 12198, 7915, 4958, 0,
144-
856, 850, 1601, 1355, 0, 7039, 14083, 4113, 7490, 0,
145-
152, 927, 287, 1902, 0, 301, 1051, 886, 2346, 0,
146-
6821, 19615, 4491, 13281, 0, 424, 1146, 999, 2906, 0,
147-
15177, 15480, 8849, 12442, 0, 1222, 544, 2687, 1859, 0,
148-
20215, 9693, 11441, 4964, 0, 1206, 555, 2466, 860, 0};
149-
std::vector<float> expected_grad_weight_data = {
150-
9246, 22073, 12431, 19714, 11179, 19032, 8458, 6495, 18707, 13830,
151-
20445, 17089, 17124, 18710, 11827, 17236, 16824, 9008, 14086, 18834,
152-
17419, 16759, 13152, 9339, 13801, 20888, 13976, 27277, 13010, 23949,
153-
9838, 11220, 17658, 15019, 25337, 17583, 13270, 21754, 16908, 20563,
154-
20732, 13413, 20868, 27521, 19537, 21170, 15888, 10034, 19195, 16370,
155-
40243, 25890, 40472, 30460, 21228, 21625, 13289, 24435, 19876, 29816,
156-
24188, 23619, 13752, 16251, 18741, 19368, 24517, 34261, 27054, 31257,
157-
21238, 18909, 15776, 16881, 34604, 22534, 28101, 23834, 18479, 16469,
158-
12852, 16551, 14204, 29983, 20167, 24150, 14281, 17501, 15897, 16019,
159-
21661, 32765, 23874, 26527, 20463, 18661};
160-
std::vector<float> expected_grad_bias_data = {363, 438, 585, 501};
70+
using CTYPE = typename decltype(tf)::ctype;
71+
std::vector<CTYPE> grad_output_data = {
72+
10, 12, 87, 13, 34, 87, 55, 22, 48, 33, 29, 38, 60, 49, 88, 30,
73+
99, 19, 42, 37, 61, 31, 33, 58, 38, 23, 2, 33, 3, 21, 32, 2,
74+
30, 72, 10, 67, 92, 19, 11, 16, 65, 37, 60, 74, 4, 19, 45, 37};
75+
std::vector<CTYPE> input_data = {
76+
9, 89, 45, 39, 25, 2, 97, 55, 80, 24, 18, 33, 28, 89, 19, 16, 19, 33,
77+
69, 61, 34, 84, 58, 30, 33, 18, 75, 30, 6, 33, 42, 10, 80, 41, 66, 64,
78+
47, 51, 67, 62, 58, 10, 97, 71, 24, 44, 84, 34, 33, 54, 8, 73, 90, 15,
79+
21, 92, 55, 22, 56, 12, 10, 63, 32, 76, 65, 38, 95, 92, 22, 15, 37, 12,
80+
67, 14, 60, 44, 73, 74, 23, 4, 56, 64, 88, 90, 82, 32, 91, 3, 6, 87,
81+
55, 95, 7, 14, 24, 69, 52, 44, 14, 37, 75, 52, 37, 40, 25, 54, 4, 15,
82+
97, 51, 46, 28, 65, 95, 50, 82, 23, 39, 50, 55, 97, 52, 91, 16, 19, 49,
83+
61, 50, 42, 47, 87, 99, 9, 60, 22, 71, 47, 17, 0, 80, 28, 88, 93, 43,
84+
65, 25, 88, 67, 21, 89, 24, 81, 3, 71, 20, 34, 17, 17, 94, 10, 82, 25,
85+
10, 11, 7, 28, 77, 39, 74, 79, 17, 40, 67, 54, 49, 54, 21, 89, 17, 7,
86+
52, 64, 68, 80, 7, 72, 44, 35, 92, 47, 4, 13, 10, 43, 64, 66, 83, 49,
87+
81, 78, 58, 22, 86, 48, 35, 64, 98, 79, 8, 52, 56, 23, 38, 74, 16, 63,
88+
51, 70, 44, 28, 43, 13, 51, 85, 42, 29, 64, 26, 54, 91, 9, 96, 41, 56,
89+
7, 52, 27, 22, 69, 13, 8, 20, 22, 49, 66, 98, 77, 42, 54, 38, 70, 83,
90+
13, 8, 21, 56, 78, 37, 28, 69, 42, 30, 91, 5, 28, 15, 20, 14, 16, 39,
91+
95, 66, 4, 72, 52, 35, 54, 93, 87, 77, 3, 49, 82, 70, 84, 3, 73, 99,
92+
32, 95, 58, 65, 32, 75, 34, 22, 12, 84, 63, 72, 85, 66, 63, 27, 3, 73,
93+
45, 37, 61, 52, 41, 16, 37, 14, 80, 17, 48, 8, 87, 98, 69, 63, 92, 68,
94+
42, 63, 5, 22, 66, 91, 74, 11, 17, 45, 45, 33, 40, 85, 26, 75, 73, 81,
95+
54, 27, 80, 1, 44, 66, 10, 21, 15, 10, 76, 96, 0, 43, 39, 3, 57, 79,
96+
45, 64, 58, 92, 44, 42, 7, 28, 94, 4, 8, 22, 22, 31, 75, 44, 3, 70,
97+
83, 72, 87, 12, 20, 55, 84, 31, 50, 34, 25, 49, 29, 71, 57, 97, 25, 82,
98+
84, 42, 86, 41, 54, 92, 34, 30, 52, 34, 84, 25, 54, 37, 38, 26, 76, 82,
99+
34, 14, 85, 28, 93, 9};
100+
std::vector<CTYPE> weight_data = {
101+
2, 54, 9, 37, 0, 47, 70, 9, 84, 69, 56, 79, 25, 35, 54, 13,
102+
65, 46, 38, 28, 74, 27, 66, 61, 20, 60, 62, 58, 15, 44, 75, 55,
103+
7, 52, 13, 36, 39, 64, 62, 45, 100, 6, 79, 63, 63, 52, 37, 60,
104+
78, 12, 69, 2, 74, 56, 93, 39, 62, 22, 55, 67, 68, 74, 12, 69,
105+
15, 73, 28, 70, 86, 20, 90, 49, 52, 26, 58, 2, 82, 17, 70, 55,
106+
54, 83, 70, 11, 27, 9, 5, 42, 34, 62, 29, 94, 69, 81, 54, 4};
107+
std::vector<CTYPE> expected_grad_input_data = {
108+
1134, 7578, 686, 2682, 0, 4148, 7136, 2406, 8698, 0,
109+
3759, 6003, 2163, 2395, 0, 2929, 5830, 3469, 6955, 0,
110+
720, 6201, 495, 2063, 0, 5260, 5989, 3060, 7079, 0,
111+
9690, 3423, 3385, 1932, 0, 7644, 8499, 1323, 2613, 0,
112+
4334, 6624, 8532, 9719, 0, 5496, 8601, 1157, 2215, 0,
113+
4676, 7600, 6524, 10069, 0, 4047, 6117, 1612, 2567, 0,
114+
5931, 5651, 5669, 6623, 0, 7674, 3291, 2748, 1654, 0,
115+
10455, 4290, 4145, 796, 0, 9835, 5483, 11649, 5952, 0,
116+
7098, 5460, 3101, 2443, 0, 7788, 5909, 8582, 6298, 0,
117+
9462, 4845, 3041, 2067, 0, 7038, 6336, 10438, 6377, 0,
118+
7518, 8187, 2079, 2773, 0, 10036, 2642, 3952, 1166, 0,
119+
16014, 2250, 10025, 1908, 0, 9610, 298, 3868, 122, 0,
120+
16629, 4338, 11335, 3527, 0, 11514, 5965, 4762, 2207, 0,
121+
18552, 10755, 13309, 5996, 0, 12454, 6787, 4960, 2875, 0,
122+
8750, 6999, 3534, 3233, 0, 14160, 9399, 9595, 8922, 0,
123+
9110, 6567, 3820, 2351, 0, 12969, 11814, 9436, 5870, 0,
124+
7631, 7061, 2877, 2499, 0, 8553, 13527, 3631, 6863, 0,
125+
1361, 8634, 515, 3372, 0, 3394, 10206, 1504, 4112, 0,
126+
5505, 17421, 4702, 11891, 0, 4233, 11894, 1739, 5014, 0,
127+
11787, 14634, 8981, 10759, 0, 11777, 6701, 4719, 3111, 0,
128+
18459, 7761, 12044, 7627, 0, 11214, 4556, 4374, 1594, 0,
129+
604, 1908, 1506, 6102, 0, 2532, 4024, 1713, 6121, 0,
130+
1878, 1814, 4761, 5397, 0, 1127, 3885, 4373, 5832, 0,
131+
450, 1414, 1080, 4719, 0, 5210, 2683, 2765, 4252, 0,
132+
2390, 1668, 7710, 4257, 0, 378, 1698, 3276, 6021, 0,
133+
2866, 4881, 3547, 6822, 0, 502, 1238, 2784, 5199, 0,
134+
2496, 3975, 2700, 5004, 0, 1220, 1990, 3633, 5763, 0,
135+
4501, 2679, 4504, 5412, 0, 1968, 1376, 6246, 3669, 0,
136+
3130, 272, 9345, 1950, 0, 5167, 3278, 9097, 2138, 0,
137+
2446, 1946, 6942, 5460, 0, 5732, 3404, 7919, 5534, 0,
138+
2038, 1614, 6978, 4635, 0, 4544, 4839, 7367, 5574, 0,
139+
1242, 1922, 4842, 6333, 0, 1066, 236, 2236, 686, 0,
140+
17238, 2254, 10413, 1592, 0, 991, 30, 2206, 70, 0,
141+
18823, 6392, 12173, 2470, 0, 1142, 684, 2742, 1219, 0,
142+
21256, 11293, 12719, 7512, 0, 1303, 649, 2818, 1669, 0,
143+
898, 574, 2018, 1929, 0, 15720, 11989, 10517, 5972, 0,
144+
885, 781, 2210, 1281, 0, 14601, 12198, 7915, 4958, 0,
145+
856, 850, 1601, 1355, 0, 7039, 14083, 4113, 7490, 0,
146+
152, 927, 287, 1902, 0, 301, 1051, 886, 2346, 0,
147+
6821, 19615, 4491, 13281, 0, 424, 1146, 999, 2906, 0,
148+
15177, 15480, 8849, 12442, 0, 1222, 544, 2687, 1859, 0,
149+
20215, 9693, 11441, 4964, 0, 1206, 555, 2466, 860, 0};
150+
std::vector<CTYPE> expected_grad_weight_data = {
151+
9246, 22073, 12431, 19714, 11179, 19032, 8458, 6495, 18707, 13830,
152+
20445, 17089, 17124, 18710, 11827, 17236, 16824, 9008, 14086, 18834,
153+
17419, 16759, 13152, 9339, 13801, 20888, 13976, 27277, 13010, 23949,
154+
9838, 11220, 17658, 15019, 25337, 17583, 13270, 21754, 16908, 20563,
155+
20732, 13413, 20868, 27521, 19537, 21170, 15888, 10034, 19195, 16370,
156+
40243, 25890, 40472, 30460, 21228, 21625, 13289, 24435, 19876, 29816,
157+
24188, 23619, 13752, 16251, 18741, 19368, 24517, 34261, 27054, 31257,
158+
21238, 18909, 15776, 16881, 34604, 22534, 28101, 23834, 18479, 16469,
159+
12852, 16551, 14204, 29983, 20167, 24150, 14281, 17501, 15897, 16019,
160+
21661, 32765, 23874, 26527, 20463, 18661};
161+
std::vector<CTYPE> expected_grad_bias_data = {363, 438, 585, 501};
161162

162-
auto grad_output = tf.make({2, 4, 3, 2}, grad_output_data);
163-
auto input = tf.make({2, 6, 7, 5}, input_data);
164-
auto weight = tf.make({4, 3, 4, 2}, weight_data);
165-
int64_t bias_sizes[1] = {4};
166-
int64_t stride[2] = {1, 2};
167-
int64_t padding[2] = {1, 0};
168-
int64_t dilation[2] = {2, 1};
169-
bool transposed = false;
170-
int64_t output_padding[2] = {0, 0};
171-
int64_t groups = 2;
172-
std::array<bool, 3> output_mask_a = {true, true, true};
173-
auto grad_input = tf.zeros({2, 6, 7, 5});
174-
auto grad_weight = tf.zeros({4, 3, 4, 2});
175-
auto grad_bias = tf.zeros({4});
163+
auto grad_output = tf.make({2, 4, 3, 2}, grad_output_data);
164+
auto input = tf.make({2, 6, 7, 5}, input_data);
165+
auto weight = tf.make({4, 3, 4, 2}, weight_data);
166+
int64_t bias_sizes[1] = {4};
167+
int64_t stride[2] = {1, 2};
168+
int64_t padding[2] = {1, 0};
169+
int64_t dilation[2] = {2, 1};
170+
bool transposed = false;
171+
int64_t output_padding[2] = {0, 0};
172+
int64_t groups = 2;
173+
std::array<bool, 3> output_mask_a = {true, true, true};
174+
auto grad_input = tf.zeros({2, 6, 7, 5});
175+
auto grad_weight = tf.zeros({4, 3, 4, 2});
176+
auto grad_bias = tf.zeros({4});
177+
178+
op_convolution_backward_out(
179+
grad_output,
180+
input,
181+
weight,
182+
IntArrayRef{bias_sizes, 1},
183+
IntArrayRef{stride, 2},
184+
IntArrayRef{padding, 2},
185+
IntArrayRef{dilation, 2},
186+
transposed,
187+
IntArrayRef{output_padding, 2},
188+
groups,
189+
output_mask_a,
190+
grad_input,
191+
grad_weight,
192+
grad_bias);
176193

177-
op_convolution_backward_out(
178-
grad_output,
179-
input,
180-
weight,
181-
IntArrayRef{bias_sizes, 1},
182-
IntArrayRef{stride, 2},
183-
IntArrayRef{padding, 2},
184-
IntArrayRef{dilation, 2},
185-
transposed,
186-
IntArrayRef{output_padding, 2},
187-
groups,
188-
output_mask_a,
189-
grad_input,
190-
grad_weight,
191-
grad_bias);
194+
auto expected_grad_input = tf.make({2, 6, 7, 5}, expected_grad_input_data);
195+
auto expected_grad_weight =
196+
tf.make({4, 3, 4, 2}, expected_grad_weight_data);
197+
auto expected_grad_bias = tf.make({4}, expected_grad_bias_data);
192198

193-
auto expected_grad_input = tf.make({2, 6, 7, 5}, expected_grad_input_data);
194-
auto expected_grad_weight = tf.make({4, 3, 4, 2}, expected_grad_weight_data);
195-
auto expected_grad_bias = tf.make({4}, expected_grad_bias_data);
199+
if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
200+
EXPECT_TENSOR_CLOSE_WITH_TOL(grad_input, expected_grad_input, 1e-2, 1e-8);
201+
EXPECT_TENSOR_CLOSE_WITH_TOL(
202+
grad_weight, expected_grad_weight, 2e-2, 1e-8);
203+
EXPECT_TENSOR_CLOSE_WITH_TOL(grad_bias, expected_grad_bias, 1e-2, 1e-8);
204+
} else {
205+
EXPECT_TENSOR_CLOSE(grad_input, expected_grad_input);
206+
EXPECT_TENSOR_CLOSE(grad_weight, expected_grad_weight);
207+
EXPECT_TENSOR_CLOSE(grad_bias, expected_grad_bias);
208+
}
209+
}
210+
};
196211

197-
EXPECT_TENSOR_CLOSE(grad_input, expected_grad_input);
198-
EXPECT_TENSOR_CLOSE(grad_weight, expected_grad_weight);
199-
EXPECT_TENSOR_CLOSE(grad_bias, expected_grad_bias);
212+
TEST_F(OpConvolutionBackwardOutTest, SmokeTest) {
213+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
214+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
215+
#undef TEST_ENTRY
200216
}

0 commit comments

Comments
 (0)