diff --git a/include/oneapi/math/rng/device/detail/engine_base.hpp b/include/oneapi/math/rng/device/detail/engine_base.hpp index 07efadb41..00b7882c8 100644 --- a/include/oneapi/math/rng/device/detail/engine_base.hpp +++ b/include/oneapi/math/rng/device/detail/engine_base.hpp @@ -39,5 +39,6 @@ class engine_base {}; #include "oneapi/math/rng/device/detail/mrg32k3a_impl.hpp" #include "oneapi/math/rng/device/detail/mcg31m1_impl.hpp" #include "oneapi/math/rng/device/detail/mcg59_impl.hpp" +#include "oneapi/math/rng/device/detail/pcg64_dxsm_impl.hpp" #endif // ONEMATH_RNG_DEVICE_ENGINE_BASE_HPP_ diff --git a/include/oneapi/math/rng/device/detail/pcg64_dxsm_impl.hpp b/include/oneapi/math/rng/device/detail/pcg64_dxsm_impl.hpp new file mode 100644 index 000000000..8fa9f530c --- /dev/null +++ b/include/oneapi/math/rng/device/detail/pcg64_dxsm_impl.hpp @@ -0,0 +1,414 @@ +/******************************************************************************* +* Copyright (C) 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef ONEMATH_RNG_DEVICE_PCG64_DXSM_IMPL_HPP_ +#define ONEMATH_RNG_DEVICE_PCG64_DXSM_IMPL_HPP_ + +namespace oneapi::math::rng::device { + +template +class pcg64_dxsm; + +namespace detail { + +struct pcg128_t { + std::uint64_t high; + std::uint64_t low; +}; + +struct pcg64_dxsm_param { + static constexpr std::uint64_t cheap_multiplier = 0xda942042e4dd58b5ULL; + static constexpr pcg128_t increment = { 0x5851f42d4c957f2dULL, 0x14057b7ef767814fULL }; +}; + +template +struct engine_state> { + pcg128_t s; + std::uint32_t result; + bool has_32; +}; + +namespace pcg64_dxsm_impl { + +static inline pcg128_t mul_64x64(std::uint64_t x, std::uint64_t y) { + const std::uint64_t x_lo = x & 0xFFFFFFFFULL; + const std::uint64_t x_hi = x >> 32; + const std::uint64_t y_lo = y & 0xFFFFFFFFULL; + const std::uint64_t y_hi = y >> 32; + + const std::uint64_t xy_md = x_hi * y_lo; + const std::uint64_t yx_md = y_hi * x_lo; + + return { (x_hi * y_hi) + (xy_md >> 32) + (yx_md >> 32) + + (((xy_md & 0xFFFFFFFFULL) + (yx_md & 0xFFFFFFFFULL) + ((x_lo * y_lo) >> 32)) >> + 32), + x * y }; +} + +static inline pcg128_t mul_lo_128x128(pcg128_t x, pcg128_t y) { + pcg128_t result = mul_64x64(x.low, y.low); + result.high += (x.high * y.low) + (x.low * y.high); + return result; +} + +static inline pcg128_t mul_lo_128x64(pcg128_t x, std::uint64_t y) { + std::uint64_t x_high = x.high; + x = mul_64x64(x.low, y); + x.high += x_high * y; + return x; +} + +static inline pcg128_t add_128x128(pcg128_t x, pcg128_t y) { + x.low += y.low; + x.high += (y.high + (x.low < y.low)); + return x; +} + +static inline std::uint64_t prepare_output(const pcg128_t local_state) { + std::uint64_t tmp = local_state.high; + + tmp ^= tmp >> 32; + tmp *= pcg64_dxsm_param::cheap_multiplier; + tmp ^= tmp >> 48; + tmp *= (local_state.low | 1); + + return tmp; +} + +static inline void update_state(pcg128_t& local_state) { + local_state = add_128x128(pcg64_dxsm_param::increment, + mul_lo_128x64(local_state, pcg64_dxsm_param::cheap_multiplier)); +} + +template +static inline void bump32(engine_state>& state) { + if (state.has_32) { + state.has_32 = false; + } + else { + std::uint64_t tmp = prepare_output(state.s); + state.result = static_cast(tmp >> 32); + + update_state(state.s); + state.has_32 = true; + } +} + +template +static inline void skip_ahead(engine_state>& state, + pcg128_t num_to_skip) { + pcg128_t acc_mul = { 0ULL, 1ULL }; + pcg128_t acc_inc = { 0ULL, 0ULL }; + pcg128_t tmp_mul = { 0ULL, pcg64_dxsm_param::cheap_multiplier }; + pcg128_t tmp_inc = pcg64_dxsm_param::increment; + bool is_skip_odd = num_to_skip.low & 1; + num_to_skip.low >>= 1; + num_to_skip.low |= num_to_skip.high << 63; + num_to_skip.high >>= 1; + + while (num_to_skip.low || num_to_skip.high) { + if (num_to_skip.low & 1) { + acc_mul = mul_lo_128x128(acc_mul, tmp_mul); + acc_inc = add_128x128(mul_lo_128x128(acc_inc, tmp_mul), tmp_inc); + } + + tmp_inc = mul_lo_128x128(add_128x128(tmp_mul, { 0, 1 }), tmp_inc); + tmp_mul = mul_lo_128x128(tmp_mul, tmp_mul); + + num_to_skip.low >>= 1; + num_to_skip.low |= num_to_skip.high << 63; + num_to_skip.high >>= 1; + } + + state.s = add_128x128(mul_lo_128x128(state.s, acc_mul), acc_inc); + + if (is_skip_odd) { + bump32(state); + } +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t num_to_skip) { + pcg128_t acc_mul = { 0ULL, 1ULL }; + pcg128_t acc_inc = { 0ULL, 0ULL }; + pcg128_t tmp_mul = { 0ULL, pcg64_dxsm_param::cheap_multiplier }; + pcg128_t tmp_inc = pcg64_dxsm_param::increment; + bool is_skip_odd = num_to_skip & 1; + num_to_skip >>= 1; + + while (num_to_skip) { + if (num_to_skip & 1) { + acc_mul = mul_lo_128x128(acc_mul, tmp_mul); + acc_inc = add_128x128(mul_lo_128x128(acc_inc, tmp_mul), tmp_inc); + } + + tmp_inc = mul_lo_128x128(add_128x128(tmp_mul, { 0, 1 }), tmp_inc); + tmp_mul = mul_lo_128x128(tmp_mul, tmp_mul); + + num_to_skip >>= 1; + } + state.s = add_128x128(mul_lo_128x128(state.s, acc_mul), acc_inc); + + if (is_skip_odd) { + bump32(state); + } +} + +template +static inline void init(engine_state>& state, + std::uint64_t n, const std::uint64_t* seed_ptr, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + state.s.low = seed_ptr[0]; + state.s.high = (n > 1 ? seed_ptr[1] : 0ULL); + + state.s = add_128x128(state.s, pcg64_dxsm_param::increment); + state.s = add_128x128(pcg64_dxsm_param::increment, + mul_lo_128x64(state.s, pcg64_dxsm_param::cheap_multiplier)); + + if (n_offset > 1) + pcg64_dxsm_impl::skip_ahead(state, { offset_ptr[1], offset_ptr[0] }); + else + pcg64_dxsm_impl::skip_ahead(state, offset_ptr[0]); + + state.has_32 = false; +} + +template +static inline sycl::vec generate64( + engine_state>& state) { + sycl::vec res; + + for (int i = 0; i < VecSize; i++) { + res[i] = prepare_output(state.s); + update_state(state.s); + } + + return res; +} + +template +static inline sycl::vec generate32_even( + engine_state>& state) { + sycl::vec res; + std::uint64_t tmp; + + for (int i = 0; i < (VecSize / 2); i++) { + tmp = prepare_output(state.s); + res[2 * i] = static_cast(tmp); + res[(2 * i) + 1] = static_cast(tmp >> 32); + + update_state(state.s); + } + + return res; +} + +template +static inline sycl::vec generate32_odd( + engine_state>& state) { + sycl::vec res; + + if constexpr (VecSize == 1) { + if (state.has_32) { + res[0] = state.result; + state.has_32 = false; + } + else { + std::uint64_t tmp = prepare_output(state.s); + res[0] = static_cast(tmp); + state.result = static_cast(tmp >> 32); + + update_state(state.s); + state.has_32 = true; + } + } + else { + std::uint64_t tmp; + int shift = (state.has_32 ? 1 : 0); + + tmp = prepare_output(state.s); + res[shift] = static_cast(tmp); + res[shift + 1] = static_cast(tmp >> 32); + + update_state(state.s); + + if (state.has_32) { + res[0] = state.result; + state.has_32 = false; + } + else { + tmp = prepare_output(state.s); + + res[2] = static_cast(tmp); + state.result = static_cast(tmp >> 32); + + update_state(state.s); + state.has_32 = true; + } + } + + return res; +} + +template +static inline sycl::vec generate32( + engine_state>& state) { + if constexpr (VecSize % 2 == 1) + return generate32_odd(state); + else + return generate32_even(state); +} + +template +static inline std::uint64_t generate_single64( + engine_state>& state) { + std::uint64_t res = prepare_output(state.s); + update_state(state.s); + return res; +} + +template +static inline std::uint64_t generate_single32( + engine_state>& state) { + std::uint32_t res; + + if (state.has_32) { + res = state.result; + state.has_32 = false; + } + else { + std::uint64_t tmp = prepare_output(state.s); + + res = static_cast(tmp); + state.result = static_cast(tmp >> 32); + + update_state(state.s); + state.has_32 = true; + } + + return res; +} + +} // namespace pcg64_dxsm_impl + +template +class engine_base> { +protected: + engine_base(std::uint64_t seed, std::uint64_t offset = 0) { + pcg64_dxsm_impl::init(this->state_, 1, &seed, 1, &offset); + } + + engine_base(std::uint64_t n, const std::uint64_t* seed, std::uint64_t offset = 0) { + pcg64_dxsm_impl::init(this->state_, n, seed, 1, &offset); + } + + engine_base(std::uint64_t seed, std::uint64_t n_offset, const std::uint64_t* offset_ptr) { + pcg64_dxsm_impl::init(this->state_, 1, &seed, n_offset, offset_ptr); + } + + engine_base(std::uint64_t n, const std::uint64_t* seed, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + pcg64_dxsm_impl::init(this->state_, n, seed, n_offset, offset_ptr); + } + + template + inline auto generate(RealType a, RealType b) -> + typename std::conditional>::type { + sycl::vec res; + sycl::vec res_uint; + RealType a1 = (b + a) / static_cast(2.0); + RealType c1 = + (b - a) / (static_cast((std::numeric_limits::max)()) + 1); + + res_uint = pcg64_dxsm_impl::generate32(this->state_); + for (int i = 0; i < VecSize; i++) { + res[i] = static_cast(static_cast(res_uint[i])) * c1 + a1; + } + return res; + } + + inline auto generate() -> typename std::conditional>::type { + return pcg64_dxsm_impl::generate32(this->state_); + } + + template + inline auto generate_uniform_bits() -> + typename std::conditional>::type { + if constexpr (std::is_same::value) { + return pcg64_dxsm_impl::generate32(this->state_); + } + else { + return pcg64_dxsm_impl::generate64(this->state_); + } + } + + template + inline RealType generate_single(RealType a, RealType b) { + std::uint32_t res_uint; + RealType res; + RealType a1 = (b + a) / static_cast(2.0); + RealType c1 = + (b - a) / (static_cast((std::numeric_limits::max)()) + 1); + + res_uint = pcg64_dxsm_impl::generate_single32(this->state_); + + res = static_cast(static_cast(res_uint)) * c1 + a1; + + return res; + } + + inline std::uint32_t generate_single() { + return pcg64_dxsm_impl::generate_single32(this->state_); + } + + template + inline auto generate_single_uniform_bits() { + if constexpr (std::is_same::value) { + return pcg64_dxsm_impl::generate_single32(this->state_); + } + else { + return pcg64_dxsm_impl::generate_single64(this->state_); + } + } + + void skip_ahead(std::uint64_t num_to_skip) { + detail::pcg64_dxsm_impl::skip_ahead(this->state_, num_to_skip); + } + + void skip_ahead(std::initializer_list num_to_skip) { + if (num_to_skip.size() > 1) { + detail::pcg64_dxsm_impl::skip_ahead(this->state_, + { num_to_skip.begin()[1], num_to_skip.begin()[0] }); + } + else { + detail::pcg64_dxsm_impl::skip_ahead(this->state_, num_to_skip.begin()[0]); + } + } + + engine_state> state_; + friend class oneapi::math::rng::device::count_engine_adaptor< + oneapi::math::rng::device::pcg64_dxsm>; +}; + +} // namespace detail +} // namespace oneapi::math::rng::device + +#endif // ONEMATH_RNG_DEVICE_PCG64_DXSM_IMPL_HPP_ diff --git a/include/oneapi/math/rng/device/engines.hpp b/include/oneapi/math/rng/device/engines.hpp index b68874eaa..587ad313e 100644 --- a/include/oneapi/math/rng/device/engines.hpp +++ b/include/oneapi/math/rng/device/engines.hpp @@ -165,6 +165,47 @@ class mcg59 : public detail::engine_base> { friend class detail::distribution_base; }; +// Class template oneapi::math::rng::device::pcg64_dxsm +// +// Represents PCG64-DXSM pseudorandom number generator +// +// Supported parallelization methods: +// skip_ahead +// +template +class pcg64_dxsm : public detail::engine_base> { +public: + static constexpr std::uint64_t default_seed = 0; + + static constexpr std::int32_t vec_size = VecSize; + + pcg64_dxsm() : detail::engine_base>(default_seed) {} + + pcg64_dxsm(std::uint64_t seed, std::uint64_t offset = 0) + : detail::engine_base>(seed, offset) {} + + pcg64_dxsm(std::initializer_list seed, std::uint64_t offset = 0) + : detail::engine_base>(seed.size(), seed.begin(), offset) {} + + pcg64_dxsm(std::uint64_t seed, std::initializer_list offset) + : detail::engine_base>(seed, offset.size(), offset.begin()) {} + + pcg64_dxsm(std::initializer_list seed, + std::initializer_list offset) + : detail::engine_base>(seed.size(), seed.begin(), offset.size(), + offset.begin()) {} + +private: + template + friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); + + template + friend void skip_ahead(Engine& engine, std::initializer_list num_to_skip); + + template + friend class detail::distribution_base; +}; + // ENGINE ADAPTORS // Class oneapi::math::rng::device::count_engine_adaptor diff --git a/tests/unit_tests/rng/device/include/moments.hpp b/tests/unit_tests/rng/device/include/moments.hpp index 7b360d5c3..7ac693ebe 100644 --- a/tests/unit_tests/rng/device/include/moments.hpp +++ b/tests/unit_tests/rng/device/include/moments.hpp @@ -106,7 +106,7 @@ class moments_test { return; } - // validation (statistics check is turned out for mcg59) + // validation (statistics check is turned off for mcg59) if constexpr (!std::is_same>::value) { statistics_device stat; diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index efebef0f7..9c6cdddc7 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -889,6 +889,356 @@ INSTANTIATE_TEST_SUITE_P(Mcg59UniformStdDeviceMomentsTestsSuite, Mcg59UniformStd INSTANTIATE_TEST_SUITE_P(Mcg59UniformAccDeviceMomentsTestsSuite, Mcg59UniformAccDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Pcg64DXSMUniformStdDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Pcg64DXSMUniformAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +/* Test small types (u)int8, (u)int16 only with uniform_method::standard since numbers are always generated +as single precision numbers */ +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, Integer8Precision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int8_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int8_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int8_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, UnsignedInteger8Precision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, Integer16Precision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int16_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int16_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int16_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, UnsignedInteger16Precision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformStdDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + float, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::math::rng::device::uniform< + double, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint32_t, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::int64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMUniformAccDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::uniform< + std::uint64_t, oneapi::math::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Pcg64DXSMUniformStdDeviceMomentsTestsSuite, + Pcg64DXSMUniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Pcg64DXSMUniformAccDeviceMomentsTestsSuite, + Pcg64DXSMUniformAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + class Philox4x32x10BitsDeviceMomentsTests : public ::testing::TestWithParam {}; TEST_P(Philox4x32x10BitsDeviceMomentsTests, UnsignedIntegerPrecision) { diff --git a/tests/unit_tests/rng/device/service/count_adaptor.cpp b/tests/unit_tests/rng/device/service/count_adaptor.cpp index b9aa7a876..f31cece39 100644 --- a/tests/unit_tests/rng/device/service/count_adaptor.cpp +++ b/tests/unit_tests/rng/device/service/count_adaptor.cpp @@ -74,4 +74,16 @@ TEST_P(Mcg59DeviceCountAdaptorTests, BinaryPrecision) { INSTANTIATE_TEST_SUITE_P(Mcg59DeviceCountAdaptorTestsSuite, Mcg59DeviceCountAdaptorTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Pcg64DXSMDeviceCountAdaptorTests : public ::testing::TestWithParam {}; + +TEST_P(Pcg64DXSMDeviceCountAdaptorTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Pcg64DXSMDeviceCountAdaptorTestsSuite, Pcg64DXSMDeviceCountAdaptorTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + } // namespace diff --git a/tests/unit_tests/rng/device/service/skip_ahead.cpp b/tests/unit_tests/rng/device/service/skip_ahead.cpp index 662c56cff..cc4a1fbec 100644 --- a/tests/unit_tests/rng/device/service/skip_ahead.cpp +++ b/tests/unit_tests/rng/device/service/skip_ahead.cpp @@ -110,4 +110,32 @@ TEST_P(Mcg59DeviceSkipAheadTests, BinaryPrecision) { INSTANTIATE_TEST_SUITE_P(Mcg59DeviceSkipAheadTestsSuite, Mcg59DeviceSkipAheadTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Pcg64DXSMDeviceSkipAheadTests : public ::testing::TestWithParam {}; + +class Pcg64DXSMDeviceSkipAheadExTests : public ::testing::TestWithParam {}; + +TEST_P(Pcg64DXSMDeviceSkipAheadTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Pcg64DXSMDeviceSkipAheadExTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Pcg64DXSMDeviceSkipAheadTestsSuite, Pcg64DXSMDeviceSkipAheadTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Pcg64DXSMDeviceSkipAheadExTestsSuite, Pcg64DXSMDeviceSkipAheadExTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + } // namespace