Skip to content

Commit ce4acb4

Browse files
authored
[oneDPL] The internal functors refactoring: reducing number of functors, renaming etc (#2291)
* [oneDPL] the internal functors refactoring: reducing number of functor, renaming etc * [oneDPL][hetero] + using std::swap; and passing ranges by ref * [oneDPL] + reverse tag for __transform_functor * [oneDPL] corrected a term in an error message * [oneDPL] auto -> bool, if it is applicable (for predicates). * [oneDPL] + missing test cases for functors __unary_op and __binary_op * [oneDPL] + missing negative test cases for functors __unary_op and __binary_op
1 parent 3f1c9a8 commit ce4acb4

File tree

6 files changed

+82
-70
lines changed

6 files changed

+82
-70
lines changed

include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,6 @@ __pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function
6868
}
6969

7070
#if _ONEDPL_CPP20_RANGES_PRESENT
71-
template <typename _F, typename _Proj>
72-
struct __pattern_transform_unary_op
73-
{
74-
_F __op;
75-
_Proj __proj;
76-
77-
template <typename _TValue>
78-
auto
79-
operator()(_TValue&& __val) const
80-
{
81-
return std::invoke(__op, std::invoke(__proj, std::forward<_TValue>(__val)));
82-
}
83-
};
8471

8572
//---------------------------------------------------------------------------------------------------------------------
8673
// pattern_for_each
@@ -89,7 +76,7 @@ template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename
8976
void
9077
__pattern_for_each(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Fun __f, _Proj __proj)
9178
{
92-
__pattern_transform_unary_op<_Fun, _Proj> __f_1{__f, __proj};
79+
oneapi::dpl::__internal::__unary_op<_Fun, _Proj> __f_1{__f, __proj};
9380

9481
oneapi::dpl::__internal::__ranges::__pattern_walk_n(__tag, std::forward<_ExecutionPolicy>(__exec), __f_1,
9582
oneapi::dpl::__ranges::views::all(std::forward<_R>(__r)));
@@ -105,7 +92,7 @@ __pattern_transform(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec,
10592
_F __op, _Proj __proj)
10693
{
10794
assert(std::ranges::size(__in_r) <= std::ranges::size(__out_r)); // for debug purposes only
108-
__pattern_transform_unary_op<_F, _Proj> __unary_op{__op, __proj};
95+
oneapi::dpl::__internal::__unary_op<_F, _Proj> __unary_op{__op, __proj};
10996

11097
oneapi::dpl::__internal::__ranges::__pattern_walk_n(__tag, std::forward<_ExecutionPolicy>(__exec),
11198
oneapi::dpl::__internal::__transform_functor<decltype(__unary_op)>{std::move(__unary_op)},
@@ -256,7 +243,7 @@ template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename
256243
auto
257244
__pattern_find_if(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Pred __pred, _Proj __proj)
258245
{
259-
__pattern_transform_unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
246+
oneapi::dpl::__internal::__unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
260247

261248
auto __idx = oneapi::dpl::__internal::__ranges::__pattern_find_if(__tag, std::forward<_ExecutionPolicy>(__exec),
262249
oneapi::dpl::__ranges::views::all_read(__r), __pred_1);
@@ -379,7 +366,7 @@ template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename
379366
bool
380367
__pattern_any_of(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Pred __pred, _Proj __proj)
381368
{
382-
__pattern_transform_unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
369+
oneapi::dpl::__internal::__unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
383370

384371
return oneapi::dpl::__internal::__ranges::__pattern_any_of(__tag, std::forward<_ExecutionPolicy>(__exec),
385372
oneapi::dpl::__ranges::views::all_read(std::forward<_R>(__r)), __pred_1);
@@ -485,7 +472,7 @@ struct __pattern_search_n_pred
485472
_Proj __proj;
486473

487474
template <typename _TValue1, typename _TValue2>
488-
auto
475+
bool
489476
operator()(_TValue1&& __val1, _TValue2&& __val2) const
490477
{
491478
return std::invoke(__pred, std::invoke(__proj, std::forward<_TValue1>(__val1)), std::forward<_TValue2>(__val2));
@@ -605,7 +592,7 @@ template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename
605592
std::ranges::range_difference_t<_R>
606593
__pattern_count_if(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Pred __pred, _Proj __proj)
607594
{
608-
__pattern_transform_unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
595+
oneapi::dpl::__internal::__unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
609596

610597
return oneapi::dpl::__internal::__ranges::__pattern_count(__tag, ::std::forward<_ExecutionPolicy>(__exec),
611598
oneapi::dpl::__ranges::views::all_read(::std::forward<_R>(__r)), __pred_1);
@@ -650,7 +637,7 @@ auto
650637
__pattern_copy_if_ranges(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _InRange&& __in_r, _OutRange&& __out_r,
651638
_Pred __pred, _Proj __proj)
652639
{
653-
__pattern_transform_unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
640+
oneapi::dpl::__internal::__unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
654641

655642
auto __res_idx = oneapi::dpl::__internal::__ranges::__pattern_copy_if(__tag,
656643
std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::__ranges::views::all_read(__in_r),

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -921,15 +921,15 @@ struct __scalar_store_transform_op
921921
// Unary transformations into an output buffer
922922
template <typename _IdxType1, typename _IdxType2, typename _SourceAcc, typename _DestAcc>
923923
void
924-
operator()(_IdxType1 __idx_source, _IdxType2 __idx_dest, _SourceAcc __source_acc, _DestAcc __dest_acc) const
924+
operator()(_IdxType1 __idx_source, _IdxType2 __idx_dest, _SourceAcc&& __source_acc, _DestAcc&& __dest_acc) const
925925
{
926926
__transform(__source_acc[__idx_source], __dest_acc[__idx_dest]);
927927
}
928928
// Binary transformations into an output buffer
929929
template <typename _IdxType1, typename _IdxType2, typename _Source1Acc, typename _Source2Acc, typename _DestAcc>
930930
void
931-
operator()(_IdxType1 __idx_source, _IdxType2 __idx_dest, _Source1Acc __source1_acc, _Source2Acc __source2_acc,
932-
_DestAcc __dest_acc) const
931+
operator()(_IdxType1 __idx_source, _IdxType2 __idx_dest, _Source1Acc&& __source1_acc, _Source2Acc&& __source2_acc,
932+
_DestAcc&& __dest_acc) const
933933
{
934934
__transform(__source1_acc[__idx_source], __source2_acc[__idx_source], __dest_acc[__idx_dest]);
935935
}
@@ -997,18 +997,24 @@ struct __vector_reverse
997997
static_assert(__vec_size <= 4, "Only vector sizes of 4 or less are supported");
998998
template <typename _Idx, typename _Array>
999999
void
1000-
operator()(/*__is_full*/ std::true_type, const _Idx /*__elements_to_process*/, _Array __array) const
1000+
operator()(/*__is_full*/ std::true_type, const _Idx /*__elements_to_process*/, _Array&& __array) const
10011001
{
10021002
_ONEDPL_PRAGMA_UNROLL
10031003
for (std::uint8_t __i = 0; __i < __vec_size / 2; ++__i)
1004-
std::swap(__array[__i], __array[__vec_size - __i - 1]);
1004+
{
1005+
using std::swap;
1006+
swap(__array[__i], __array[__vec_size - __i - 1]);
1007+
}
10051008
}
10061009
template <typename _Idx, typename _Array>
10071010
void
1008-
operator()(/*__is_full*/ std::false_type, const _Idx __elements_to_process, _Array __array) const
1011+
operator()(/*__is_full*/ std::false_type, const _Idx __elements_to_process, _Array&& __array) const
10091012
{
10101013
for (std::uint8_t __i = 0; __i < __elements_to_process / 2; ++__i)
1011-
std::swap(__array[__i], __array[__elements_to_process - __i - 1]);
1014+
{
1015+
using std::swap;
1016+
swap(__array[__i], __array[__elements_to_process - __i - 1]);
1017+
}
10121018
}
10131019
};
10141020

include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,13 @@ class __set_value;
7171
template <typename _Comp, typename _Proj>
7272
struct __predicate;
7373

74+
template <typename _F, typename _Proj>
75+
struct __unary_op;
76+
7477
template <typename _F, typename _Proj1, typename _Proj2>
7578
struct __binary_op;
7679

77-
template <typename _Pred>
80+
template <typename _Pred, typename _RevTag>
7881
class __transform_functor;
7982

8083
template <typename _UnaryOper, typename _UnaryPred>
@@ -151,9 +154,6 @@ struct __parallel_reduce_by_segment_fallback_fn1;
151154
template <typename _BinaryPredicate>
152155
struct __parallel_reduce_by_segment_fallback_fn2;
153156

154-
template <typename _Op, typename _It1ValueT, typename _It2ValueTRef>
155-
struct __pattern_adjacent_difference_op_caller_fn;
156-
157157
} // namespace oneapi::dpl::__internal
158158

159159
template <typename _Pred>
@@ -192,14 +192,20 @@ struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal::
192192
{
193193
};
194194

195+
template <typename _F, typename _Proj>
196+
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal::__unary_op, _F, _Proj)>
197+
: oneapi::dpl::__internal::__are_all_device_copyable<_F, _Proj>
198+
{
199+
};
200+
195201
template <typename _F, typename _Proj1, typename _Proj2>
196202
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal::__binary_op, _F, _Proj1, _Proj2)>
197203
: oneapi::dpl::__internal::__are_all_device_copyable<_F, _Proj1, _Proj2>
198204
{
199205
};
200206

201-
template <typename _Pred>
202-
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal::__transform_functor, _Pred)>
207+
template <typename _Pred, typename _RevTag>
208+
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal::__transform_functor, _Pred, _RevTag)>
203209
: oneapi::dpl::__internal::__are_all_device_copyable<_Pred>
204210
{
205211
};
@@ -365,13 +371,6 @@ struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(
365371
{
366372
};
367373

368-
template <typename _Op, typename _It1ValueT, typename _It2ValueTRef>
369-
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(
370-
oneapi::dpl::__internal::__pattern_adjacent_difference_op_caller_fn, _Op, _It1ValueT, _It2ValueTRef)>
371-
: oneapi::dpl::__internal::__are_all_device_copyable<_Op, _It1ValueT, _It2ValueTRef>
372-
{
373-
};
374-
375374
namespace oneapi::dpl::__internal::__ranges
376375
{
377376

include/oneapi/dpl/pstl/hetero/numeric_impl_hetero.h

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,6 @@ struct adjacent_difference_wrapper
223223
{
224224
};
225225

226-
template <typename _Op, typename _It1ValueT, typename _It2ValueTRef>
227-
struct __pattern_adjacent_difference_op_caller_fn
228-
{
229-
_Op __op;
230-
231-
// TODO investigate why we can't use oneapi::dpl::__internal::__transform_functor
232-
// instead this predicate
233-
void
234-
operator()(_It1ValueT __in1, _It1ValueT __in2, _It2ValueTRef __out1) const
235-
{
236-
__out1 = __op(__in2, __in1); // This move assignment is allowed by the C++ standard draft N4810
237-
}
238-
};
239-
240226
template <typename _BackendTag, typename _ExecutionPolicy, typename _ForwardIterator1, typename _ForwardIterator2,
241227
typename _BinaryOperation>
242228
_ForwardIterator2
@@ -247,9 +233,6 @@ __pattern_adjacent_difference(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __ex
247233
if (__n <= 0)
248234
return __d_first;
249235

250-
using _It1ValueT = typename ::std::iterator_traits<_ForwardIterator1>::value_type;
251-
using _It2ValueTRef = typename ::std::iterator_traits<_ForwardIterator2>::reference;
252-
253236
_ForwardIterator2 __d_last = __d_first + __n;
254237

255238
#if !__SYCL_UNNAMED_LAMBDA__
@@ -265,7 +248,7 @@ __pattern_adjacent_difference(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __ex
265248
else
266249
#endif
267250
{
268-
__pattern_adjacent_difference_op_caller_fn<_BinaryOperation, _It1ValueT, _It2ValueTRef> __fn{__op};
251+
oneapi::dpl::__internal::__transform_functor<_BinaryOperation, std::true_type> __fn{__op};
269252

270253
auto __keep1 =
271254
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _ForwardIterator1>();

include/oneapi/dpl/pstl/utils.h

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,21 @@ struct __predicate
150150
template <typename _Comp, typename _Proj>
151151
using __compare = __predicate<_Comp, _Proj>;
152152

153+
template <typename _F, typename _Proj>
154+
struct __unary_op
155+
{
156+
//'mutable' is to relax the requirements for a user functor or/and projection type operator() may be non-const
157+
mutable _F __f;
158+
mutable _Proj __proj;
159+
160+
template <typename _TValue>
161+
decltype(auto)
162+
operator()(_TValue&& __val) const
163+
{
164+
return std::invoke(__f, std::invoke(__proj, std::forward<_TValue>(__val)));
165+
}
166+
};
167+
153168
template <typename _F, typename _Proj1, typename _Proj2>
154169
struct __binary_op
155170
{
@@ -159,7 +174,7 @@ struct __binary_op
159174
mutable _Proj2 __proj2;
160175

161176
template <typename _TValue1, typename _TValue2>
162-
auto
177+
decltype(auto)
163178
operator()(_TValue1&& __val1, _TValue2&& __val2) const
164179
{
165180
return std::invoke(__f, std::invoke(__proj1, std::forward<_TValue1>(__val1)),
@@ -299,37 +314,41 @@ class __set_value
299314

300315
//TODO: to do the same fix for output type (by re-using __transform_functor if applicable) for the other functor below:
301316
// __transform_if_unary_functor, __transform_if_binary_functor, __replace_functor, __replace_copy_functor
302-
//TODO: to make input type consistently: const T& or T&&; to think which way is preferable
303-
template <typename _Pred>
317+
template <typename _F, typename _RevTag = std::false_type>
304318
class __transform_functor
305319
{
306-
mutable _Pred _M_pred;
320+
mutable _F __f;
307321

308322
public:
309-
explicit __transform_functor(_Pred __pred) : _M_pred(::std::move(__pred)) {}
323+
explicit __transform_functor(_F __f) : __f(std::move(__f)) {}
310324

311325
template <typename _Input1Type, typename _Input2Type, typename _OutputType>
312326
void
313-
operator()(const _Input1Type& __x, const _Input2Type& __y, _OutputType&& __output) const
327+
operator()(_Input1Type&& __x, _Input2Type&& __y, _OutputType&& __output) const
314328
{
315-
__transform_impl(::std::forward<_OutputType>(__output), __x, __y);
329+
if constexpr (_RevTag())
330+
__transform_impl(std::forward<_OutputType>(__output), std::forward<_Input1Type>(__y),
331+
std::forward<_Input2Type>(__x));
332+
else
333+
__transform_impl(std::forward<_OutputType>(__output), std::forward<_Input1Type>(__x),
334+
std::forward<_Input2Type>(__y));
316335
}
317336

318337
template <typename _InputType, typename _OutputType>
319338
void
320339
operator()(_InputType&& __x, _OutputType&& __output) const
321340
{
322-
__transform_impl(::std::forward<_OutputType>(__output), ::std::forward<_InputType>(__x));
341+
__transform_impl(std::forward<_OutputType>(__output), std::forward<_InputType>(__x));
323342
}
324343

325344
private:
326345
template <typename _OutputType, typename... _Args>
327346
void
328347
__transform_impl(_OutputType&& __output, _Args&&... __args) const
329348
{
330-
static_assert(sizeof...(_Args) < 3, "A predicate supports either unary or binary transformation");
331-
static_assert(::std::is_invocable_v<_Pred, _Args...>, "A predicate cannot be called with the passed arguments");
332-
::std::forward<_OutputType>(__output) = _M_pred(::std::forward<_Args>(__args)...);
349+
static_assert(sizeof...(_Args) < 3, "A functor supports either unary or binary transformation");
350+
static_assert(::std::is_invocable_v<_F, _Args...>, "A functor cannot be called with the passed arguments");
351+
std::forward<_OutputType>(__output) = __f(std::forward<_Args>(__args)...);
333352
}
334353
};
335354

@@ -954,7 +973,7 @@ struct __count_fn_pred
954973
_Proj __proj;
955974

956975
template <typename _TValue>
957-
auto
976+
bool
958977
operator()(_TValue&& __val) const
959978
{
960979
return std::ranges::equal_to{}(std::invoke(__proj, std::forward<_TValue>(__val)), __value);

test/general/implementation_details/device_copyable.pass.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,15 @@ test_device_copyable()
315315
sycl::is_device_copyable_v<oneapi::dpl::__internal::tuple<std::tuple<int_device_copyable, int_device_copyable>,
316316
int_device_copyable, int_device_copyable>>,
317317
"tuple is not device copyable with device copyable types");
318+
319+
//__unary_op
320+
static_assert(sycl::is_device_copyable_v<oneapi::dpl::__internal::__unary_op<noop_device_copyable,
321+
noop_device_copyable>>, "__unary_op is not device copyable with device copyable types");
322+
323+
//__binary_op
324+
static_assert(sycl::is_device_copyable_v<oneapi::dpl::__internal::__binary_op<noop_device_copyable,
325+
noop_device_copyable, noop_device_copyable>>,
326+
"__binary_op is not device copyable with device copyable types");
318327
}
319328

320329
void
@@ -592,6 +601,15 @@ test_non_device_copyable()
592601
static_assert(!sycl::is_device_copyable_v<oneapi::dpl::__internal::tuple<
593602
std::tuple<int_non_device_copyable, int_device_copyable>, int_device_copyable>>,
594603
"tuple is device copyable with non device copyable types");
604+
605+
//__unary_op
606+
static_assert(!sycl::is_device_copyable_v<oneapi::dpl::__internal::__unary_op<noop_non_device_copyable,
607+
noop_non_device_copyable>>, "__unary_op is device copyable with non device copyable types");
608+
609+
//__binary_op
610+
static_assert(!sycl::is_device_copyable_v<oneapi::dpl::__internal::__binary_op<noop_non_device_copyable,
611+
noop_non_device_copyable, noop_non_device_copyable>>,
612+
"__binary_op is device copyable with non device copyable types");
595613
}
596614

597615
#endif // TEST_DPCPP_BACKEND_PRESENT

0 commit comments

Comments
 (0)