Skip to content

Commit 3f1c9a8

Browse files
authored
Relocate SYCL scan-by-segment to pstl/hetero/dpcpp (#2275)
* Moves SYCL specific code to `pstl/hetero/dpcpp` to align with our general design --------- Signed-off-by: Matthew Michel <[email protected]>
1 parent 7b5c0e9 commit 3f1c9a8

File tree

6 files changed

+172
-165
lines changed

6 files changed

+172
-165
lines changed

include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
#include "function.h"
2424
#include "by_segment_extension_defs.h"
2525
#include "../pstl/utils.h"
26-
#include "scan_by_segment_impl.h"
26+
27+
#if _ONEDPL_BACKEND_SYCL
28+
# include "../pstl/hetero/algorithm_impl_hetero.h"
29+
#endif
2730

2831
namespace oneapi
2932
{

include/oneapi/dpl/internal/inclusive_scan_by_segment_impl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
#include "../pstl/parallel_backend.h"
2828
#include "function.h"
2929
#include "../pstl/utils.h"
30-
#include "scan_by_segment_impl.h"
30+
31+
#if _ONEDPL_BACKEND_SYCL
32+
# include "../pstl/hetero/algorithm_impl_hetero.h"
33+
#endif
3134

3235
namespace oneapi
3336
{

include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,6 +2180,41 @@ __pattern_reduce_by_segment(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&&
21802180
__out_keys.all_view(), __out_values.all_view(), __binary_pred, __binary_op);
21812181
}
21822182

2183+
template <typename _BackendTag, typename _Policy, typename _InputIterator1, typename _InputIterator2,
2184+
typename _OutputIterator, typename _T, typename _BinaryPredicate, typename _Operator, typename _Inclusive>
2185+
_OutputIterator
2186+
__pattern_scan_by_segment(__hetero_tag<_BackendTag>, _Policy&& __policy, _InputIterator1 __first1,
2187+
_InputIterator1 __last1, _InputIterator2 __first2, _OutputIterator __result, _T __init,
2188+
_BinaryPredicate __binary_pred, _Operator __binary_op, _Inclusive)
2189+
{
2190+
const auto __n = std::distance(__first1, __last1);
2191+
2192+
// Check for empty element ranges
2193+
if (__n <= 0)
2194+
return __result;
2195+
2196+
namespace __bknd = oneapi::dpl::__par_backend_hetero;
2197+
2198+
auto __keep_keys = oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read, _InputIterator1>();
2199+
auto __key_buf = __keep_keys(__first1, __last1);
2200+
auto __keep_values = oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read, _InputIterator2>();
2201+
auto __value_buf = __keep_values(__first2, __first2 + __n);
2202+
auto __keep_value_outputs =
2203+
oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read_write, _OutputIterator>();
2204+
auto __value_output_buf = __keep_value_outputs(__result, __result + __n);
2205+
using _IterValueType = typename std::iterator_traits<_InputIterator2>::value_type;
2206+
2207+
// Currently, this pattern requires a known identity for the binary operator.
2208+
static_assert(unseq_backend::__has_known_identity<_Operator, _IterValueType>::value,
2209+
"Calls to __pattern_scan_by_segment require a known identity for the provided binary operator");
2210+
constexpr _IterValueType __identity = unseq_backend::__known_identity<_Operator, _IterValueType>;
2211+
2212+
__bknd::__parallel_scan_by_segment<_Inclusive::value>(
2213+
_BackendTag{}, std::forward<_Policy>(__policy), __key_buf.all_view(), __value_buf.all_view(),
2214+
__value_output_buf.all_view(), __binary_pred, __binary_op, __init, __identity);
2215+
return __result + __n;
2216+
}
2217+
21832218
} // namespace __internal
21842219
} // namespace dpl
21852220
} // namespace oneapi

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "parallel_backend_sycl_merge_sort.h"
4444
#include "parallel_backend_sycl_reduce_by_segment.h"
4545
#include "parallel_backend_sycl_reduce_then_scan.h"
46+
#include "parallel_backend_sycl_scan_by_segment.h"
4647
#include "execution_sycl_defs.h"
4748
#include "sycl_iterator.h"
4849
#include "unseq_backend_sycl.h"

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_by_segment.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
#include "sycl_traits.h"
5656

5757
#include "../../utils.h"
58-
#include "../../../internal/scan_by_segment_impl.h"
58+
#include "parallel_backend_sycl_scan_by_segment.h"
5959

6060
namespace oneapi
6161
{

0 commit comments

Comments
 (0)