@@ -551,3 +551,30 @@ TEST_F(OpMeanOutTest, DTypeOutFloatNAN) {
551551 Tensor ret = op_mean_dtype_out (x, ScalarType::Float, out);
552552 EXPECT_TENSOR_CLOSE (out, expected_result);
553553}
554+
555+ TEST_F (OpMeanOutTest, EmptyInput) {
556+ TensorFactory<ScalarType::Float> tf;
557+
558+ Tensor x = tf.make ({2 , 0 , 3 }, {});
559+ optional<ScalarType> dtype = ScalarType::Float;
560+ optional<ArrayRef<int64_t >> dim_list = ArrayRef<int64_t >{};
561+ Tensor out = tf.zeros ({1 , 1 , 1 });
562+ op_mean_out (x, dim_list, /* keepdim=*/ true , dtype, out);
563+ EXPECT_TENSOR_CLOSE (out, tf.make ({1 , 1 , 1 }, {NAN}));
564+
565+ out = tf.zeros ({});
566+ op_mean_out (x, dim_list, /* keepdim=*/ false , dtype, out);
567+ EXPECT_TENSOR_CLOSE (out, tf.make ({}, {NAN}));
568+
569+ int64_t dims1[1 ] = {1 };
570+ dim_list = ArrayRef<int64_t >{dims1, 1 };
571+ out = tf.zeros ({2 , 3 });
572+ op_mean_out (x, dim_list, /* keepdim=*/ false , dtype, out);
573+ EXPECT_TENSOR_CLOSE (out, tf.make ({2 , 3 }, {NAN, NAN, NAN, NAN, NAN, NAN}));
574+
575+ int64_t dims2[1 ] = {2 };
576+ dim_list = ArrayRef<int64_t >{dims2, 1 };
577+ out = tf.make ({2 , 0 , 1 }, {});
578+ op_mean_out (x, dim_list, /* keepdim=*/ true , dtype, out);
579+ EXPECT_TENSOR_CLOSE (out, tf.make ({2 , 0 , 1 }, {}));
580+ }
0 commit comments