Skip to content

Commit 66dd786

Browse files
author
Alexandre Hoffmann
committed
feurst commiteuh
1 parent cbb9a36 commit 66dd786

File tree

2 files changed

+212
-0
lines changed

2 files changed

+212
-0
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#ifndef XTENSOR_INDEX_MAPPER_HPP
11+
#define XTENSOR_INDEX_MAPPER_HPP
12+
13+
#include "xview.hpp"
14+
15+
namespace xt
16+
{
17+
18+
template<class UndefinedView> struct index_mapper;
19+
20+
template<class UnderlyingContainer, class... Slices>
21+
class index_mapper< xt::xview<UnderlyingContainer, Slices...> >
22+
{
23+
static constexpr size_t n_slices = sizeof...(Slices);
24+
static constexpr size_t n_free = ((!std::is_integral_v<Slices>) + ... );
25+
public:
26+
using view_type = xt::xview<UnderlyingContainer, Slices...>;
27+
28+
using value_type = typename xt::xview<UnderlyingContainer, Slices...>::value_type;
29+
private:
30+
template<size_t I> using ith_slice_type = std::tuple_element_t<I, std::tuple<Slices...> >;
31+
32+
template<size_t first, size_t... indices>
33+
struct indices_sequence_helper
34+
{
35+
using not_new_axis_type = typename indices_sequence_helper<first + 1, indices..., first>::Type; // we add the current axis
36+
using new_axis_type = typename indices_sequence_helper<first + 1, indices...>::Type; // we skip the current axis
37+
38+
using Type = std::conditional_t< xt::detail::is_newaxis< ith_slice_type<first> >::value , new_axis_type, not_new_axis_type>;
39+
};
40+
41+
// closing the recurence
42+
template<size_t... indices>
43+
struct indices_sequence_helper<n_slices, indices...>
44+
{
45+
using Type = std::index_sequence<indices...>;
46+
};
47+
48+
using indices_sequence = indices_sequence_helper<0>::Type;
49+
50+
static constexpr size_t n_all_indices = indices_sequence::size();
51+
52+
template<size_t I, std::integral Index>
53+
size_t map_ith_index(const view_type& view, const Index i) const;
54+
55+
template<size_t... Is>
56+
value_type map_all_indices(const UnderlyingContainer& container, const view_type& view, std::index_sequence<Is...>, const std::array<size_t, n_all_indices>& indices) const
57+
requires(sizeof...(Is) == n_all_indices);
58+
59+
template<size_t... Is>
60+
value_type map_at_all_indices(const UnderlyingContainer& container, const view_type& view, std::index_sequence<Is...>, const std::array<size_t, n_all_indices>& indices) const
61+
requires(sizeof...(Is) == n_all_indices);
62+
public:
63+
template<std::integral... Indices>
64+
value_type map(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const
65+
requires(sizeof...(Indices) == n_free);
66+
67+
template<std::integral... Indices>
68+
value_type map_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const
69+
requires(sizeof...(Indices) == n_free);
70+
71+
constexpr size_t dimension() const { return n_free; }
72+
};
73+
74+
/*******************************
75+
* index_mapper implementation *
76+
*******************************/
77+
78+
template<class UnderlyingContainer, class... Slices> template<std::integral... Indices>
79+
auto index_mapper< xt::xview<UnderlyingContainer, Slices...> >::map(
80+
const UnderlyingContainer& container,
81+
const view_type& view,
82+
const Indices... indices) const -> value_type
83+
requires(sizeof...(Indices) == n_free)
84+
{
85+
std::array<size_t, sizeof...(indices)> args{ size_t(indices)...};
86+
87+
auto it = std::cbegin(args);
88+
std::array<size_t, n_all_indices> args_full{ (std::is_integral_v<Slices> ? size_t(0) : *it++)... };
89+
90+
return map_all_indices(container, view, indices_sequence{}, args_full);
91+
}
92+
93+
template<class UnderlyingContainer, class... Slices> template<std::integral... Indices>
94+
auto index_mapper< xt::xview<UnderlyingContainer, Slices...> >::map_at(
95+
const UnderlyingContainer& container,
96+
const view_type& view,
97+
const Indices... indices) const -> value_type
98+
requires(sizeof...(Indices) == n_free)
99+
{
100+
std::array<size_t, sizeof...(indices)> args{ size_t(indices)...};
101+
102+
auto it = std::cbegin(args);
103+
std::array<size_t, n_all_indices> args_full{ (std::is_integral_v<Slices> ? size_t(0) : *it++)... };
104+
105+
return map_at_all_indices(container, view, indices_sequence{}, args_full);
106+
}
107+
108+
template<class UnderlyingContainer, class... Slices> template<size_t... Is>
109+
auto index_mapper< xt::xview<UnderlyingContainer, Slices...> >::map_all_indices(const UnderlyingContainer& container, const view_type& view, std::index_sequence<Is...>, const std::array<size_t, n_all_indices>& indices) const -> value_type
110+
requires(sizeof...(Is) == n_all_indices)
111+
{
112+
return container(map_ith_index<Is>(view, indices[Is])...);
113+
}
114+
115+
template<class UnderlyingContainer, class... Slices> template<size_t... Is>
116+
auto index_mapper< xt::xview<UnderlyingContainer, Slices...> >::map_at_all_indices(const UnderlyingContainer& container, const view_type& view, std::index_sequence<Is...>, const std::array<size_t, n_all_indices>& indices) const -> value_type
117+
requires(sizeof...(Is) == n_all_indices)
118+
{
119+
return container.at(map_ith_index<Is>(view, indices[Is])...);
120+
}
121+
122+
template<class UnderlyingContainer, class... Slices> template<size_t I, std::integral Index>
123+
auto index_mapper< xt::xview<UnderlyingContainer, Slices...> >::map_ith_index(const view_type& view, const Index i) const -> size_t
124+
{
125+
using current_slice = std::tuple_element_t<I, std::tuple<Slices...>>;
126+
127+
static_assert(not xt::detail::is_newaxis<current_slice>::value);
128+
129+
const auto& slice = std::get<I>(view.slices());
130+
131+
if constexpr (std::is_integral_v<current_slice>) { assert(i == 0); return size_t(slice); }
132+
else { assert(i < slice.size()); return size_t(slice(i)); }
133+
}
134+
135+
} // namespace xt
136+
137+
#endif // XTENSOR_INDEX_MAPPER_HPP

test/test_xview.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "xtensor/misc/xmanipulation.hpp"
3131
#include "xtensor/views/xstrided_view.hpp"
3232
#include "xtensor/views/xview.hpp"
33+
#include "xtensor/views/index_mapper.hpp"
3334

3435
namespace xt
3536
{
@@ -143,6 +144,57 @@ namespace xt
143144
}
144145
}
145146

147+
TEST(xview_mapping, simple)
148+
{
149+
view_shape_type shape = {3, 4};
150+
xarray<double> a(shape);
151+
std::vector<double> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
152+
std::copy(data.cbegin(), data.cend(), a.template begin<layout_type::row_major>());
153+
154+
auto view1 = view(a, 1, range(1, 4));
155+
156+
index_mapper<decltype(view1)> mapper1;
157+
158+
EXPECT_EQ(a(1, 1), mapper1.map(a, view1, 0));
159+
EXPECT_EQ(a(1, 2), mapper1.map(a, view1, 1));
160+
EXPECT_EQ(size_t(1), mapper1.dimension());
161+
//~ XT_EXPECT_ANY_THROW(mapper1.map_at(a, view1, 10));
162+
163+
auto view0 = view(a, 0, range(0, 3));
164+
index_mapper<decltype(view0)> mapper0;
165+
166+
EXPECT_EQ(a(0, 0), mapper0.map(a, view0, 0));
167+
EXPECT_EQ(a(0, 1), mapper0.map(a, view0, 1));
168+
EXPECT_EQ(size_t(1), mapper0.dimension());
169+
170+
auto view2 = view(a, range(0, 2), 2);
171+
index_mapper<decltype(view2)> mapper2;
172+
EXPECT_EQ(a(0, 2), mapper2.map(a, view2, 0));
173+
EXPECT_EQ(a(1, 2), mapper2.map(a, view2, 1));
174+
EXPECT_EQ(size_t(1), mapper2.dimension());
175+
176+
//~ auto view4 = view(a, 1);
177+
//~ index_mapper<decltype(view4)> mapper4;
178+
//~ EXPECT_EQ(size_t(1), mapper4.dimension());
179+
//~
180+
//~ auto view5 = view(view4, 1);
181+
//~ index_mapper<decltype(view5)> mapper5;
182+
//~ EXPECT_EQ(size_t(0), mapper5.dimension());
183+
184+
auto view6 = view(a, 1, all());
185+
index_mapper<decltype(view6)> mapper6;
186+
EXPECT_EQ(a(1, 0), mapper6.map(a, view6, 0));
187+
EXPECT_EQ(a(1, 1), mapper6.map(a, view6, 1));
188+
EXPECT_EQ(a(1, 2), mapper6.map(a, view6, 2));
189+
EXPECT_EQ(a(1, 3), mapper6.map(a, view6, 3));
190+
191+
auto view7 = view(a, all(), 2);
192+
index_mapper<decltype(view7)> mapper7;
193+
EXPECT_EQ(a(0, 2), mapper7.map(a, view7, 0));
194+
EXPECT_EQ(a(1, 2), mapper7.map(a, view7, 1));
195+
EXPECT_EQ(a(2, 2), mapper7.map(a, view7, 2));
196+
}
197+
146198
TEST(xview, negative_index)
147199
{
148200
view_shape_type shape = {3, 4};
@@ -269,6 +321,29 @@ namespace xt
269321
EXPECT_EQ(a(1, 1, 1), view1.element(idx.cbegin(), idx.cend()));
270322
}
271323

324+
TEST(xview_mapping, three_dimensional)
325+
{
326+
view_shape_type shape = {3, 4, 2};
327+
std::vector<double> data = {1, 2, 3, 4, 5, 6, 7, 8,
328+
329+
9, 10, 11, 12, 21, 22, 23, 24,
330+
331+
25, 26, 27, 28, 29, 210, 211, 212};
332+
xarray<double> a(shape);
333+
std::copy(data.cbegin(), data.cend(), a.template begin<layout_type::row_major>());
334+
335+
auto view1 = view(a, 1, all(), all());
336+
index_mapper<decltype(view1)> mapper1;
337+
338+
EXPECT_EQ(size_t(2), mapper1.dimension());
339+
std::cout << "===================================" << std::endl;
340+
EXPECT_EQ(a(1, 0, 0), mapper1.map(a, view1, 0, 0));
341+
EXPECT_EQ(a(1, 0, 1), mapper1.map(a, view1, 0, 1));
342+
//~ EXPECT_EQ(a(1, 1, 0), mapper1.map(a, view1, 1, 0));
343+
//~ EXPECT_EQ(a(1, 1, 1), mapper1.map(a, view1, 1, 1));
344+
//~ XT_EXPECT_ANY_THROW(mapper1.map_at(a, view1, 10, 10));
345+
}
346+
272347
TEST(xview, integral_count)
273348
{
274349
size_t squeeze1 = integral_count<size_t, size_t, size_t, xrange<size_t>>();

0 commit comments

Comments
 (0)