Skip to content

Commit 0df9d44

Browse files
swolchokZonglin Peng
authored andcommitted
Support Half/BFloat16 in roll (pytorch#7861)
Partial fix for pytorch#7748.
1 parent b59f027 commit 0df9d44

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

kernels/portable/cpu/op_roll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ Tensor& roll_out(
8181

8282
constexpr auto name = "roll.out";
8383

84-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
84+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
8585
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
8686
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
8787

kernels/test/op_roll_test.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,26 @@ class OpRollOutTest : public ::testing::Test {
3737
// first.
3838
torch::executor::runtime_init();
3939
}
40+
41+
template <ScalarType DTYPE>
42+
void test_dtype() {
43+
TensorFactory<DTYPE> tf;
44+
45+
Tensor input = tf.make({4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
46+
int64_t shifts_data[2] = {2, 1};
47+
ArrayRef<int64_t> shifts = ArrayRef<int64_t>(shifts_data, 2);
48+
int64_t dims_data[2] = {0, 1};
49+
ArrayRef<int64_t> dims = ArrayRef<int64_t>(dims_data, 2);
50+
Tensor out = tf.zeros({4, 2});
51+
Tensor out_expected = tf.make({4, 2}, {6, 5, 8, 7, 2, 1, 4, 3});
52+
op_roll_out(input, shifts, dims, out);
53+
EXPECT_TENSOR_CLOSE(out, out_expected);
54+
}
4055
};
4156

4257
TEST_F(OpRollOutTest, SmokeTest) {
43-
TensorFactory<ScalarType::Float> tfFloat;
44-
45-
Tensor input = tfFloat.make({4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
46-
int64_t shifts_data[2] = {2, 1};
47-
ArrayRef<int64_t> shifts = ArrayRef<int64_t>(shifts_data, 2);
48-
int64_t dims_data[2] = {0, 1};
49-
ArrayRef<int64_t> dims = ArrayRef<int64_t>(dims_data, 2);
50-
Tensor out = tfFloat.zeros({4, 2});
51-
Tensor out_expected = tfFloat.make({4, 2}, {6, 5, 8, 7, 2, 1, 4, 3});
52-
op_roll_out(input, shifts, dims, out);
53-
EXPECT_TENSOR_CLOSE(out, out_expected);
58+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
59+
// TODO: enable bool test after #7856 lands.
60+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
61+
#undef TEST_ENTRY
5462
}

0 commit comments

Comments
 (0)