Skip to content

Commit 4c62edc

Browse files
author
Alexandre Hoffmann
committed
mapper works even if when the user provides more/less indices than the number of specified slices. If more indices are provided we assume an slice. The underlying container handles the excess/missing indices
1 parent 1d46493 commit 4c62edc

File tree

2 files changed

+175
-81
lines changed

2 files changed

+175
-81
lines changed

include/xtensor/views/index_mapper.hpp

Lines changed: 144 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -45,57 +45,76 @@ namespace xt
4545
template <class UnderlyingContainer, class... Slices>
4646
class index_mapper<xt::xview<UnderlyingContainer, Slices...>>
4747
{
48-
static constexpr size_t n_slices = sizeof...(Slices); ///< @brief Total number of slices in the view
49-
static constexpr size_t n_free = ((!std::is_integral_v<Slices>) +...); ///< @brief Number of free
50-
///< (non-integral) slices
48+
/// @brief Total number of explicitly passed slices in the view
49+
static constexpr size_t n_slices = sizeof...(Slices);
50+
51+
/// @brief Number of slices that are integral constants (fixed indices)
52+
static constexpr size_t nb_integral_slices = (std::is_integral_v<Slices> + ...);
53+
54+
/// @brief Number of slices that are xt::newaxis (insert a dimension)
55+
static constexpr size_t nb_new_axis_slices = (xt::detail::is_newaxis<Slices>::value + ...);
56+
57+
/**
58+
* Compute how many indices are needed to address the underlying container
59+
* when given N indices in the view.
60+
*/
61+
template <std::integral... Indices>
62+
static constexpr size_t n_indices_full_v = size_t(
63+
sizeof...(Indices) + nb_integral_slices - nb_new_axis_slices
64+
);
5165

5266
public:
5367

54-
using view_type = xt::xview<UnderlyingContainer, Slices...>; ///< @brief The view type this mapper
55-
///< works with
68+
/// @brief The view type this mapper works with
69+
using view_type = xt::xview<UnderlyingContainer, Slices...>;
5670

57-
using value_type = typename xt::xview<UnderlyingContainer, Slices...>::value_type; ///< @brief Value
58-
///< type of the
59-
///< underlying
60-
///< container
71+
///< @brief Value type of the underlying container
72+
using value_type = typename xt::xview<UnderlyingContainer, Slices...>::value_type;
6173

6274
private:
6375

6476
/// @brief Helper type alias for the I-th slice type
6577
template <size_t I>
6678
using ith_slice_type = std::tuple_element_t<I, std::tuple<Slices...>>;
6779

80+
/// @brief True if the I-th slice is an integral slice (fixed index)
81+
template <size_t I>
82+
static consteval bool is_ith_slice_integral();
83+
84+
/// @brief True if the I-th slice is a newaxis slice
85+
template <size_t I>
86+
static consteval bool is_ith_slice_new_axis();
87+
6888
/**
69-
* @brief Helper metafunction to generate an index sequence excluding newaxis slices.
89+
* Helper metafunction to build an index_sequence that skips
90+
* newaxis slices.
7091
*
71-
* This recursive template builds an `std::index_sequence` containing indices of slices
72-
* that are not `xt::newaxis`. Newaxis slices increase dimensionality but don't correspond
73-
* to actual dimensions in the underlying container.
74-
*
75-
* @tparam first Current slice index being processed.
76-
* @tparam indices... Accumulated indices of non-newaxis slices.
92+
* The resulting sequence contains only the indices that
93+
* correspond to real container dimensions.
7794
*/
78-
template <size_t first, size_t... indices>
95+
template <size_t first, size_t bound, size_t... indices>
7996
struct indices_sequence_helper
8097
{
81-
using not_new_axis_type = typename indices_sequence_helper<first + 1, indices..., first>::Type; // we add the current axis
82-
using new_axis_type = typename indices_sequence_helper<first + 1, indices...>::Type; // we skip
83-
// the
84-
// current
85-
// axis
98+
// we add the current axis
99+
using not_new_axis_type = typename indices_sequence_helper<first + 1, bound, indices..., first>::Type;
100+
101+
// we skip the current axis
102+
using new_axis_type = typename indices_sequence_helper<first + 1, bound, indices...>::Type;
86103

87-
using Type = std::conditional_t<xt::detail::is_newaxis<ith_slice_type<first>>::value, new_axis_type, not_new_axis_type>;
104+
// NOTE: is_ith_slice_new_axis works even if first >= sizeof...(Slices)
105+
using Type = std::conditional_t<is_ith_slice_new_axis<first>(), new_axis_type, not_new_axis_type>;
88106
};
89107

90108
/// @brief Base case: recursion termination
91-
template <size_t... indices>
92-
struct indices_sequence_helper<n_slices, indices...>
109+
template <size_t bound, size_t... indices>
110+
struct indices_sequence_helper<bound, bound, indices...>
93111
{
94112
using Type = std::index_sequence<indices...>;
95113
};
96114

97-
/// @brief Index sequence of non-newaxis
98-
using indices_sequence = indices_sequence_helper<0>::Type;
115+
///< @brief Index sequence of non-newaxis slices
116+
template <size_t bound>
117+
using indices_sequence = indices_sequence_helper<0, bound>::Type;
99118

100119
/**
101120
* @brief Maps an index for a specific slice to the corresponding index in the underlying container.
@@ -124,12 +143,12 @@ namespace xt
124143
* @param indices Array of indices for all slices.
125144
* @return value_type The value at the mapped location in the container.
126145
*/
127-
template <size_t... Is>
146+
template <size_t n_indices, size_t... Is>
128147
value_type map_all_indices(
129148
const UnderlyingContainer& container,
130149
const view_type& view,
131150
std::index_sequence<Is...>,
132-
const std::array<size_t, n_slices>& indices
151+
const std::array<size_t, n_indices>& indices
133152
) const;
134153

135154
/**
@@ -145,14 +164,18 @@ namespace xt
145164
*
146165
* @throws std::out_of_range if any index is out of bounds.
147166
*/
148-
template <size_t... Is>
167+
template <size_t n_indices, size_t... Is>
149168
value_type map_at_all_indices(
150169
const UnderlyingContainer& container,
151170
const view_type& view,
152171
std::index_sequence<Is...>,
153-
const std::array<size_t, n_slices>& indices
172+
const std::array<size_t, n_indices>& indices
154173
) const;
155174

175+
/// @brief Expand view indices into a full index array, inserting dummy indices for integral slices
176+
template <std::integral... Indices>
177+
std::array<size_t, n_indices_full_v<Indices...>> get_indices_full(const Indices... indices) const;
178+
156179
public:
157180

158181
/**
@@ -167,19 +190,15 @@ namespace xt
167190
* @param indices The indices for the free dimensions of the view.
168191
* @return value_type The value at the mapped location in the container.
169192
*
170-
* @note The number of provided indices must equal `n_free` (number of non-integral slices).
171-
*
172193
* @example
173194
* @code
174195
* // For view(a, 1, all(), all()):
175-
* // n_free = 2 (two all() slices)
176196
* mapper.map(a, view, i, j); // Maps to a(1, i, j)
177197
* @endcode
178198
*/
179199
template <std::integral... Indices>
180200
value_type
181-
map(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const
182-
requires(sizeof...(Indices) == n_free);
201+
map(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
183202

184203
/**
185204
* @brief Maps view indices to container indices with bounds checking.
@@ -196,34 +215,78 @@ namespace xt
196215
*/
197216
template <std::integral... Indices>
198217
value_type
199-
map_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const
200-
requires(sizeof...(Indices) == n_free);
218+
map_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const;
201219

202-
constexpr size_t dimension() const
203-
{
204-
return n_free;
205-
}
220+
/// @brief Return the dimensionality of the view
221+
size_t dimension(const UnderlyingContainer& container) const;
206222
};
207223

208224
/*******************************
209225
* index_mapper implementation *
210226
*******************************/
211227

228+
template <class UnderlyingContainer, class... Slices>
229+
template <size_t I>
230+
consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_ith_slice_integral()
231+
{
232+
if constexpr (I < sizeof...(Slices))
233+
{
234+
return std::is_integral_v<ith_slice_type<I>>;
235+
}
236+
else
237+
{
238+
return false;
239+
}
240+
}
241+
242+
template <class UnderlyingContainer, class... Slices>
243+
template <size_t I>
244+
consteval bool index_mapper<xt::xview<UnderlyingContainer, Slices...>>::is_ith_slice_new_axis()
245+
{
246+
if constexpr (I < sizeof...(Slices))
247+
{
248+
return xt::detail::is_newaxis<ith_slice_type<I>>::value;
249+
}
250+
else
251+
{
252+
return false;
253+
}
254+
}
255+
256+
template <class UnderlyingContainer, class... Slices>
257+
template <std::integral... Indices>
258+
auto
259+
index_mapper<xt::xview<UnderlyingContainer, Slices...>>::get_indices_full(const Indices... indices) const
260+
-> std::array<size_t, n_indices_full_v<Indices...>>
261+
{
262+
constexpr size_t n_indices_full = n_indices_full_v<Indices...>;
263+
264+
std::array<size_t, sizeof...(indices)> args{size_t(indices)...};
265+
std::array<size_t, n_indices_full> args_full;
266+
267+
const auto fill_args_full = [&args_full, &args]<size_t... Is>(std::index_sequence<Is...>)
268+
{
269+
auto it = std::cbegin(args);
270+
271+
((args_full[Is] = (is_ith_slice_integral<Is>()) ? size_t(0) : *it++), ...);
272+
};
273+
274+
fill_args_full(std::make_index_sequence<n_indices_full>{});
275+
276+
return args_full;
277+
}
278+
212279
template <class UnderlyingContainer, class... Slices>
213280
template <std::integral... Indices>
214281
auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map(
215282
const UnderlyingContainer& container,
216283
const view_type& view,
217284
const Indices... indices
218285
) const -> value_type
219-
requires(sizeof...(Indices) == n_free)
220286
{
221-
std::array<size_t, sizeof...(indices)> args{size_t(indices)...};
222-
223-
auto it = std::cbegin(args);
224-
std::array<size_t, n_slices> args_full{(std::is_integral_v<Slices> ? size_t(0) : *it++)...};
287+
constexpr size_t n_indices_full = n_indices_full_v<Indices...>;
225288

226-
return map_all_indices(container, view, indices_sequence{}, args_full);
289+
return map_all_indices(container, view, indices_sequence<n_indices_full>{}, get_indices_full(indices...));
227290
}
228291

229292
template <class UnderlyingContainer, class... Slices>
@@ -233,35 +296,31 @@ namespace xt
233296
const view_type& view,
234297
const Indices... indices
235298
) const -> value_type
236-
requires(sizeof...(Indices) == n_free)
237299
{
238-
std::array<size_t, sizeof...(indices)> args{size_t(indices)...};
239-
240-
auto it = std::cbegin(args);
241-
std::array<size_t, n_slices> args_full{(std::is_integral_v<Slices> ? size_t(0) : *it++)...};
300+
constexpr size_t n_indices_full = n_indices_full_v<Indices...>;
242301

243-
return map_at_all_indices(container, view, indices_sequence{}, args_full);
302+
return map_at_all_indices(container, view, indices_sequence<n_indices_full>{}, get_indices_full(indices...));
244303
}
245304

246305
template <class UnderlyingContainer, class... Slices>
247-
template <size_t... Is>
306+
template <size_t n_indices, size_t... Is>
248307
auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_all_indices(
249308
const UnderlyingContainer& container,
250309
const view_type& view,
251310
std::index_sequence<Is...>,
252-
const std::array<size_t, n_slices>& indices
311+
const std::array<size_t, n_indices>& indices
253312
) const -> value_type
254313
{
255314
return container(map_ith_index<Is>(view, indices[Is])...);
256315
}
257316

258317
template <class UnderlyingContainer, class... Slices>
259-
template <size_t... Is>
318+
template <size_t n_indices, size_t... Is>
260319
auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_at_all_indices(
261320
const UnderlyingContainer& container,
262321
const view_type& view,
263322
std::index_sequence<Is...>,
264-
const std::array<size_t, n_slices>& indices
323+
const std::array<size_t, n_indices>& indices
265324
) const -> value_type
266325
{
267326
return container.at(map_ith_index<Is>(view, indices[Is])...);
@@ -273,24 +332,40 @@ namespace xt
273332
index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_ith_index(const view_type& view, const Index i) const
274333
-> size_t
275334
{
276-
using current_slice = std::tuple_element_t<I, std::tuple<Slices...>>;
277-
278-
static_assert(not xt::detail::is_newaxis<current_slice>::value);
279-
280-
const auto& slice = std::get<I>(view.slices());
281-
282-
if constexpr (std::is_integral_v<current_slice>)
335+
if constexpr (I < sizeof...(Slices))
283336
{
284-
assert(i == 0);
285-
return size_t(slice);
337+
// if the slice is explicitly specified, use it
338+
using current_slice = std::tuple_element_t<I, std::tuple<Slices...>>;
339+
340+
static_assert(not xt::detail::is_newaxis<current_slice>::value);
341+
342+
const auto& slice = std::get<I>(view.slices());
343+
344+
if constexpr (std::is_integral_v<current_slice>)
345+
{
346+
assert(i == 0);
347+
return size_t(slice);
348+
}
349+
else
350+
{
351+
assert(i < slice.size());
352+
return size_t(slice(i));
353+
}
286354
}
287355
else
288356
{
289-
assert(i < slice.size());
290-
return size_t(slice(i));
357+
// else assume xt::all
358+
return i;
291359
}
292360
}
293361

362+
template <class UnderlyingContainer, class... Slices>
363+
auto index_mapper<xt::xview<UnderlyingContainer, Slices...>>::dimension(const UnderlyingContainer& container
364+
) const -> size_t
365+
{
366+
return container.dimension() - nb_integral_slices + nb_new_axis_slices;
367+
}
368+
294369
} // namespace xt
295370

296371
#endif // XTENSOR_INDEX_MAPPER_HPP

0 commit comments

Comments
 (0)