@@ -28,6 +28,19 @@ class OpGluOutTest : public OperatorTest {
2828 return torch::executor::aten::glu_outf (context_, self, dim, out);
2929 }
3030
31+ template <ScalarType DTYPE>
32+ void expect_tensor_close (Tensor actual, Tensor expected) {
33+ if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
34+ EXPECT_TENSOR_CLOSE_WITH_TOL (
35+ actual,
36+ expected,
37+ 1e-2 ,
38+ executorch::runtime::testing::internal::kDefaultAtol );
39+ } else {
40+ EXPECT_TENSOR_CLOSE (actual, expected);
41+ }
42+ }
43+
3144 // Common testing for glu operator
3245 template <ScalarType DTYPE, ScalarType OUT_DTYPE>
3346 void test_glu_out () {
@@ -41,14 +54,14 @@ class OpGluOutTest : public OperatorTest {
4154 Tensor in = tf.ones (sizes);
4255 Tensor out = tf_out.zeros (out_sizes_1);
4356 op_glu_out (in, 0 , out);
44- EXPECT_TENSOR_CLOSE (
57+ expect_tensor_close<DTYPE> (
4558 out,
4659 tf_out.make (
4760 out_sizes_1, /* data=*/ {0.731059 , 0.731059 , 0.731059 , 0.731059 }));
4861 const std::vector<int32_t > out_sizes_2 = {4 , 1 };
4962 out = tf_out.zeros (out_sizes_2);
5063 op_glu_out (in, 1 , out);
51- EXPECT_TENSOR_CLOSE (
64+ expect_tensor_close<DTYPE> (
5265 out,
5366 tf_out.make (
5467 out_sizes_2, /* data=*/ {0.731059 , 0.731059 , 0.731059 , 0.731059 }));
0 commit comments