3131#include " xstrides.hpp"
3232#include " xtensor_simd.hpp"
3333#include " xutils.hpp"
34+ #include " xdevice.hpp"
3435
3536namespace xt
3637{
@@ -283,6 +284,7 @@ namespace xt
283284 using const_iterator = typename iterable_base::const_iterator;
284285 using reverse_iterator = typename iterable_base::reverse_iterator;
285286 using const_reverse_iterator = typename iterable_base::const_reverse_iterator;
287+ using device_return_type = host_device_batch<value_type>;
286288
287289 template <class Func , class ... CTA, class U = std::enable_if_t <!std::is_base_of<std::decay_t <Func>, self_type>::value>>
288290 xfunction (Func&& f, CTA&&... e) noexcept ;
@@ -361,6 +363,8 @@ namespace xt
361363 template <class align , class requested_type = value_type, std::size_t N = xt_simd::simd_traits<requested_type>::size>
362364 simd_return_type<requested_type> load_simd (size_type i) const ;
363365
366+ device_return_type load_device () const ;
367+
364368 const tuple_type& arguments () const noexcept ;
365369
366370 const functor_type& functor () const noexcept ;
@@ -385,6 +389,9 @@ namespace xt
385389 template <class align , class requested_type , std::size_t N, std::size_t ... I>
386390 auto load_simd_impl (std::index_sequence<I...>, size_type i) const ;
387391
392+ template <std::size_t ... I>
393+ inline auto load_device_impl (std::index_sequence<I...>) const ;
394+
388395 template <class Func , std::size_t ... I>
389396 const_stepper build_stepper (Func&& f, std::index_sequence<I...>) const noexcept ;
390397
@@ -844,6 +851,12 @@ namespace xt
844851 return operator ()();
845852 }
846853
854+ template <class F , class ... CT>
855+ inline auto xfunction<F, CT...>::load_device() const -> device_return_type
856+ {
857+ return load_device_impl (std::make_index_sequence<sizeof ...(CT)>());
858+ }
859+
847860 template <class F , class ... CT>
848861 template <class align , class requested_type , std::size_t N>
849862 inline auto xfunction<F, CT...>::load_simd(size_type i) const -> simd_return_type<requested_type>
@@ -912,6 +925,13 @@ namespace xt
912925 return m_f.simd_apply ((std::get<I>(m_e).template load_simd <align, requested_type>(i))...);
913926 }
914927
928+ template <class F , class ... CT>
929+ template <std::size_t ... I>
930+ inline auto xfunction<F, CT...>::load_device_impl(std::index_sequence<I...>) const
931+ {
932+ return m_f.device_apply ((std::get<I>(m_e).load_device ())...);
933+ }
934+
915935 template <class F , class ... CT>
916936 template <class Func , std::size_t ... I>
917937 inline auto xfunction<F, CT...>::build_stepper(Func&& f, std::index_sequence<I...>) const noexcept
0 commit comments