diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h index 561b80884ca..4411bee8a0d 100644 --- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h @@ -1229,6 +1229,54 @@ struct __mismatch_fn inline constexpr __internal::__mismatch_fn mismatch; +// [alg.starts.with] [alg.ends.with] + +namespace __internal +{ +struct __starts_with_fn +{ + template + requires oneapi::dpl::is_execution_policy_v> && + std::ranges::sized_range<_R1> && std::ranges::sized_range<_R2> && + std::indirectly_comparable, std::ranges::iterator_t<_R2>, + _Pred, _Proj1, _Proj2> + bool + operator()(_ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _Pred __pred = {}, _Proj1 __proj1 = {}, + _Proj2 __proj2 = {}) const + { + // To ensure no dangling iterator is returned, __r2 may not be forwarded + return std::ranges::end(__r2) == oneapi::dpl::ranges::mismatch(std::forward<_ExecutionPolicy>(__exec), __r1, + __r2, __pred, __proj1, __proj2).in2; + } +}; // __starts_with_fn + +struct __ends_with_fn +{ + template + requires oneapi::dpl::is_execution_policy_v> && + std::ranges::sized_range<_R1> && std::ranges::sized_range<_R2> && + std::indirectly_comparable, std::ranges::iterator_t<_R2>, + _Pred, _Proj1, _Proj2> + bool + operator()(_ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _Pred __pred = {}, _Proj1 __proj1 = {}, + _Proj2 __proj2 = {}) const + { + using _Size = std::common_type_t; + _Size __n1 = std::ranges::size(__r1); + _Size __n2 = std::ranges::size(__r2); + return !(__n1 < __n2) && oneapi::dpl::ranges::equal(std::forward<_ExecutionPolicy>(__exec), + std::views::all(__r1) | std::views::drop(__n1 - __n2), + std::forward<_R2>(__r2),__pred, __proj1, __proj2); + } +}; // __ends_with_fn +} // __internal + +inline constexpr __internal::__starts_with_fn starts_with; +inline constexpr __internal::__ends_with_fn ends_with; + + // [alg.remove_if] namespace __internal diff --git a/test/parallel_api/ranges/std_ranges_ends_with.pass.cpp b/test/parallel_api/ranges/std_ranges_ends_with.pass.cpp new file mode 100644 index 00000000000..91999aa252c --- /dev/null +++ b/test/parallel_api/ranges/std_ranges_ends_with.pass.cpp @@ -0,0 +1,54 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2026 UXL Foundation Contributors +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "std_ranges_test.h" + +#include "std_ranges_test.h" + +#if _ENABLE_STD_RANGES_TESTING + template + using launcher = test_std_ranges::test_range_algo; +#if __cpp_lib_ranges_starts_ends_with >= 202106L + auto checker = TEST_PREPARE_CALLABLE(std::ranges::ends_with); +#else + struct { + template + bool operator()(R1&& r1, R2&& r2, Pred pred = {}, Proj1 proj1 = {}, Proj2 proj2 = {}) + { + std::ranges::reverse_view r2_reverse{r2}; + auto last = r2_reverse.end(); + return std::ranges::mismatch(std::ranges::reverse_view(r1), r2_reverse, pred, proj1, proj2).in2 == last; + } + } checker; +#endif +#endif + +std::int32_t +main() +{ +#if _ENABLE_STD_RANGES_TESTING + using namespace test_std_ranges; + namespace dpl_ranges = oneapi::dpl::ranges; + + using data_gen_shift_med = decltype([](auto i){ return i + medium_size/2; }); + using data_gen_shift_big = decltype([](auto i){ return i + big_size/2; }); + + launcher<0, int>{big_sz}(dpl_ranges::ends_with, checker, binary_pred_const); + launcher<1, int>{}(dpl_ranges::ends_with, checker, binary_pred, proj); + launcher<2, int, decltype(proj)>{}(dpl_ranges::ends_with, checker, binary_pred, proj); + launcher<3, P2>{}(dpl_ranges::ends_with, checker, binary_pred_const, &P2::x, &P2::proj); + launcher<4, P2>{}(dpl_ranges::ends_with, checker, binary_pred, &P2::proj, &P2::x); + launcher<5, int, data_gen_shift_med>{}(dpl_ranges::ends_with, checker); + launcher<6, int, data_gen_shift_big>{big_sz}(dpl_ranges::ends_with, checker); +#endif //_ENABLE_STD_RANGES_TESTING + + return TestUtils::done(_ENABLE_STD_RANGES_TESTING); +} diff --git a/test/parallel_api/ranges/std_ranges_starts_with.pass.cpp b/test/parallel_api/ranges/std_ranges_starts_with.pass.cpp new file mode 100644 index 00000000000..3e259f0272b --- /dev/null +++ b/test/parallel_api/ranges/std_ranges_starts_with.pass.cpp @@ -0,0 +1,52 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2026 UXL Foundation Contributors +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "std_ranges_test.h" + +#include "std_ranges_test.h" + +#if _ENABLE_STD_RANGES_TESTING + template + using launcher = test_std_ranges::test_range_algo; +#if __cpp_lib_ranges_starts_ends_with >= 202106L + auto checker = TEST_PREPARE_CALLABLE(std::ranges::starts_with); +#else + struct { + template + bool operator()(R1&& r1, R2&& r2, Pred pred = {}, Proj1 proj1 = {}, Proj2 proj2 = {}) + { + auto last = std::ranges::end(r2); + return std::ranges::mismatch(r1, r2, pred, proj1, proj2).in2 == last; + } + } checker; +#endif +#endif + +std::int32_t +main() +{ +#if _ENABLE_STD_RANGES_TESTING + using namespace test_std_ranges; + namespace dpl_ranges = oneapi::dpl::ranges; + + auto almost_always_i = [](auto i){ return (i == medium_size/2 + 19)? 0 : i; }; + using data_gen_needle = decltype(almost_always_i); + + launcher<0, int>{big_sz}(dpl_ranges::starts_with, checker, binary_pred_const); + launcher<1, int>{}(dpl_ranges::starts_with, checker, binary_pred, proj); + launcher<2, int, decltype(proj)>{}(dpl_ranges::starts_with, checker, binary_pred, proj); + launcher<3, P2>{}(dpl_ranges::starts_with, checker, binary_pred_const, &P2::x, &P2::proj); + launcher<4, P2>{}(dpl_ranges::starts_with, checker, binary_pred, &P2::proj, &P2::x); + launcher<5, int, data_gen_needle>{}(dpl_ranges::starts_with, checker); +#endif //_ENABLE_STD_RANGES_TESTING + + return TestUtils::done(_ENABLE_STD_RANGES_TESTING); +}