diff --git a/include/xtensor/xmasked_view.hpp b/include/xtensor/xmasked_view.hpp index de0ac9df3..f5c251f0c 100644 --- a/include/xtensor/xmasked_view.hpp +++ b/include/xtensor/xmasked_view.hpp @@ -119,14 +119,17 @@ namespace xt using bool_load_type = xtl::xmasked_value; using shape_type = typename data_type::shape_type; - using strides_type = typename data_type::strides_type; static constexpr layout_type static_layout = data_type::static_layout; static constexpr bool contiguous_layout = false; using inner_shape_type = typename data_type::inner_shape_type; - using inner_strides_type = typename data_type::inner_strides_type; - using inner_backstrides_type = typename data_type::inner_backstrides_type; + using inner_strides_type = xtl::mpl::eval_if_t, + detail::expr_inner_strides_type, + get_strides_type>; + using inner_backstrides_type = xtl::mpl::eval_if_t, + detail::expr_inner_backstrides_type, + get_strides_type>; using expression_tag = xtensor_expression_tag; diff --git a/test/test_xmasked_view.cpp b/test/test_xmasked_view.cpp index f11111682..cacaf7c7a 100644 --- a/test/test_xmasked_view.cpp +++ b/test/test_xmasked_view.cpp @@ -245,12 +245,12 @@ namespace xt TEST(xmasked_view, view) { - xt::xarray data = {{0,1}, {2,3}, {4,5}}; - xt::xarray data_new = xt::zeros(data.shape()); - xt::xarray col_mask = {false, true}; + xarray data = {{0,1}, {2,3}, {4,5}}; + xarray data_new = zeros(data.shape()); + xarray col_mask = {false, true}; - auto row_masked = xt::masked_view(xt::view(data, 0, xt::all()), col_mask); - auto new_row_masked = xt::masked_view(xt::view(data_new, 0, xt::all()), col_mask); + auto row_masked = masked_view(view(data, 0, all()), col_mask); + auto new_row_masked = masked_view(view(data_new, 0, all()), col_mask); row_masked += 10; new_row_masked = row_masked; @@ -258,4 +258,20 @@ namespace xt EXPECT_EQ(data_new(0, 0), size_t(0)); EXPECT_EQ(data_new(0, 1), size_t(11)); } + + TEST(xmasked_view, xfunction) + { + xt::xarray data = {{0,1}, {2,3}, {4,5}}; + xt::xarray data_new = xt::zeros(data.shape()); + xt::xarray mask = {{true, false}, + {false, true}, + {true, false}}; + + masked_view(data_new, mask) = masked_view(2UL*data + 1UL, mask); + + xarray expected = {{1, 0}, + {0, 7}, + {9, 0}}; + EXPECT_EQ(data_new, expected); + } }