Skip to content

Commit 2808659

Browse files
swolchokZonglin Peng
authored andcommitted
Support Half/BFloat16 in constant_pad_nd (pytorch#7806)
Partial fix for pytorch#7748.
1 parent f153dd9 commit 2808659

File tree

2 files changed

+31
-32
lines changed

2 files changed

+31
-32
lines changed

kernels/portable/cpu/op_constant_pad_nd.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,17 +184,16 @@ Tensor& constant_pad_nd_out(
184184
ScalarType in_type = in.scalar_type();
185185
ScalarType value_type = utils::get_scalar_dtype(value);
186186

187-
ET_SWITCH_REAL_TYPES_AND(
188-
Bool, in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
189-
CTYPE value_v;
190-
ET_SWITCH_SCALAR_OBJ_TYPES(
191-
value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
192-
CTYPE_VALUE val;
193-
utils::extract_scalar(value, &val);
194-
value_v = static_cast<CTYPE>(val);
195-
});
196-
constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
197-
});
187+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
188+
CTYPE value_v;
189+
ET_SWITCH_SCALAR_OBJ_TYPES(
190+
value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
191+
CTYPE_VALUE val;
192+
utils::extract_scalar(value, &val);
193+
value_v = static_cast<CTYPE>(val);
194+
});
195+
constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
196+
});
198197

199198
return out;
200199
}

kernels/test/op_constant_pad_nd_test.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
5050
5, 6, 7, 8,
5151
1, 2, 3, 4,
5252
5, 6, 7, 8,
53-
53+
5454
1, 2, 3, 4,
5555
5, 6, 7, 8,
5656
1, 2, 3, 4,
@@ -66,7 +66,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
6666
7, 5, 6, 7, 8, 7,
6767
7, 1, 2, 3, 4, 7,
6868
7, 5, 6, 7, 8, 7,
69-
69+
7070
7, 1, 2, 3, 4, 7,
7171
7, 5, 6, 7, 8, 7,
7272
7, 1, 2, 3, 4, 7,
@@ -98,7 +98,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
9898
5, 6, 7, 8,
9999
1, 2, 3, 4,
100100
5, 6, 7, 8,
101-
101+
102102
1, 2, 3, 4,
103103
5, 6, 7, 8,
104104
1, 2, 3, 4,
@@ -116,7 +116,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
116116
5, 6, 7, 8,
117117
1, 2, 3, 4,
118118
5, 6, 7, 8,
119-
119+
120120
7, 7, 7, 7,
121121
7, 7, 7, 7,
122122
1, 2, 3, 4,
@@ -150,7 +150,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
150150
5, 6, 7, 8,
151151
1, 2, 3, 4,
152152
5, 6, 7, 8,
153-
153+
154154
1, 2, 3, 4,
155155
5, 6, 7, 8,
156156
1, 2, 3, 4,
@@ -166,12 +166,12 @@ class OpConstantPadNDOutTest : public OperatorTest {
166166
7, 7, 7, 7,
167167
7, 7, 7, 7,
168168
7, 7, 7, 7,
169-
169+
170170
1, 2, 3, 4,
171171
5, 6, 7, 8,
172172
1, 2, 3, 4,
173173
5, 6, 7, 8,
174-
174+
175175
1, 2, 3, 4,
176176
5, 6, 7, 8,
177177
1, 2, 3, 4,
@@ -203,7 +203,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
203203
5, 6, 7, 8,
204204
1, 2, 3, 4,
205205
5, 6, 7, 8,
206-
206+
207207
1, 2, 3, 4,
208208
5, 6, 7, 8,
209209
1, 2, 3, 4,
@@ -221,7 +221,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
221221
7, 7, 5, 6, 7, 8, 7,
222222
7, 7, 7, 7, 7, 7, 7,
223223
7, 7, 7, 7, 7, 7, 7,
224-
224+
225225
7, 7, 1, 2, 3, 4, 7,
226226
7, 7, 5, 6, 7, 8, 7,
227227
7, 7, 1, 2, 3, 4, 7,
@@ -255,7 +255,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
255255
5, 6, 7, 8,
256256
1, 2, 3, 4,
257257
5, 6, 7, 8,
258-
258+
259259
1, 2, 3, 4,
260260
5, 6, 7, 8,
261261
1, 2, 3, 4,
@@ -271,12 +271,12 @@ class OpConstantPadNDOutTest : public OperatorTest {
271271
7, 7, 5, 6, 7, 8, 7,
272272
7, 7, 1, 2, 3, 4, 7,
273273
7, 7, 5, 6, 7, 8, 7,
274-
274+
275275
7, 7, 1, 2, 3, 4, 7,
276276
7, 7, 5, 6, 7, 8, 7,
277277
7, 7, 1, 2, 3, 4, 7,
278278
7, 7, 5, 6, 7, 8, 7,
279-
279+
280280
7, 7, 7, 7, 7, 7, 7,
281281
7, 7, 7, 7, 7, 7, 7,
282282
7, 7, 7, 7, 7, 7, 7,
@@ -308,7 +308,7 @@ class OpConstantPadNDOutTest : public OperatorTest {
308308
5, 6, 7, 8,
309309
1, 2, 3, 4,
310310
5, 6, 7, 8,
311-
311+
312312
1, 2, 3, 4,
313313
5, 6, 7, 8,
314314
1, 2, 3, 4,
@@ -325,13 +325,13 @@ class OpConstantPadNDOutTest : public OperatorTest {
325325
7, 7, 5, 6, 7, 8, 7,
326326
7, 7, 1, 2, 3, 4, 7,
327327
7, 7, 5, 6, 7, 8, 7,
328-
328+
329329
7, 7, 7, 7, 7, 7, 7,
330330
7, 7, 1, 2, 3, 4, 7,
331331
7, 7, 5, 6, 7, 8, 7,
332332
7, 7, 1, 2, 3, 4, 7,
333333
7, 7, 5, 6, 7, 8, 7,
334-
334+
335335
7, 7, 7, 7, 7, 7, 7,
336336
7, 7, 7, 7, 7, 7, 7,
337337
7, 7, 7, 7, 7, 7, 7,
@@ -353,47 +353,47 @@ TEST_F(OpConstantPadNDOutTest, TestPadDim2) {
353353
#define TEST_ENTRY(ctype, dtype) \
354354
test_constant_pad_nd_out_dim2<ScalarType::dtype>();
355355

356-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
356+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
357357
#undef TEST_ENTRY
358358
}
359359

360360
TEST_F(OpConstantPadNDOutTest, TestPadDim1) {
361361
#define TEST_ENTRY(ctype, dtype) \
362362
test_constant_pad_nd_out_dim1<ScalarType::dtype>();
363363

364-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
364+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
365365
#undef TEST_ENTRY
366366
}
367367

368368
TEST_F(OpConstantPadNDOutTest, TestPadDim0) {
369369
#define TEST_ENTRY(ctype, dtype) \
370370
test_constant_pad_nd_out_dim0<ScalarType::dtype>();
371371

372-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
372+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
373373
#undef TEST_ENTRY
374374
}
375375

376376
TEST_F(OpConstantPadNDOutTest, TestPadDim1And2) {
377377
#define TEST_ENTRY(ctype, dtype) \
378378
test_constant_pad_nd_out_dim12<ScalarType::dtype>();
379379

380-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
380+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
381381
#undef TEST_ENTRY
382382
}
383383

384384
TEST_F(OpConstantPadNDOutTest, TestPadDim0And2) {
385385
#define TEST_ENTRY(ctype, dtype) \
386386
test_constant_pad_nd_out_dim02<ScalarType::dtype>();
387387

388-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
388+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
389389
#undef TEST_ENTRY
390390
}
391391

392392
TEST_F(OpConstantPadNDOutTest, TestPadDim0And1And2) {
393393
#define TEST_ENTRY(ctype, dtype) \
394394
test_constant_pad_nd_out_dim012<ScalarType::dtype>();
395395

396-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
396+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
397397
#undef TEST_ENTRY
398398
}
399399

0 commit comments

Comments
 (0)