diff --git a/CMakeLists.txt b/CMakeLists.txt index 705448b83..6cf7feaa8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,7 +106,8 @@ endif() if(ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND OR ENABLE_CURAND_BACKEND - OR ENABLE_ROCRAND_BACKEND) + OR ENABLE_ROCRAND_BACKEND + OR ENABLE_ARMPL_BACKEND) list(APPEND DOMAINS_LIST "rng") endif() if(ENABLE_MKLGPU_BACKEND diff --git a/README.md b/README.md index 528e8081e..e755373be 100644 --- a/README.md +++ b/README.md @@ -276,12 +276,18 @@ Supported compilers include: Dynamic, Static - RNG + RNG x86 CPU Intel(R) oneMKL Intel DPC++
AdaptiveCpp Dynamic, Static + + aarch64 CPU + Arm Performance Libraries + Open DPC++
AdaptiveCpp + Dynamic, Static + Intel GPU Intel(R) oneMKL diff --git a/include/oneapi/math/detail/backends_table.hpp b/include/oneapi/math/detail/backends_table.hpp index 48e5f98c6..5dd435cf1 100644 --- a/include/oneapi/math/detail/backends_table.hpp +++ b/include/oneapi/math/detail/backends_table.hpp @@ -175,6 +175,12 @@ static std::map>> libraries = { #ifdef ONEMATH_ENABLE_MKLCPU_BACKEND LIB_NAME("rng_mklcpu") +#endif + } }, + { device::aarch64cpu, + { +#ifdef ONEMATH_ENABLE_ARMPL_BACKEND + LIB_NAME("rng_armpl") #endif } }, { device::intelgpu, diff --git a/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp b/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp new file mode 100644 index 000000000..298c0fee3 --- /dev/null +++ b/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2025 SiPearl +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMATH_RNG_ARMPL_HPP_ +#define _ONEMATH_RNG_ARMPL_HPP_ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/math/detail/export.hpp" +#include "oneapi/math/rng/detail/engine_impl.hpp" + +namespace oneapi { +namespace math { +namespace rng { +namespace armpl { + +ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_philox4x32x10(sycl::queue queue, + std::uint64_t seed); + +ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_philox4x32x10( + sycl::queue queue, std::initializer_list seed); + +ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue, + std::uint32_t seed); + +ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_mrg32k3a( + sycl::queue queue, std::initializer_list seed); + +} // namespace armpl +} // namespace rng +} // namespace math +} // namespace oneapi + +#endif //_ONEMATH_RNG_ARMPL_HPP_ diff --git a/include/oneapi/math/rng/engines.hpp b/include/oneapi/math/rng/engines.hpp index f09243459..3ae489ff1 100644 --- a/include/oneapi/math/rng/engines.hpp +++ b/include/oneapi/math/rng/engines.hpp @@ -47,6 +47,9 @@ #ifdef ONEMATH_ENABLE_ROCRAND_BACKEND #include "oneapi/math/rng/detail/rocrand/onemath_rng_rocrand.hpp" #endif +#ifdef ONEMATH_ENABLE_ARMPL_BACKEND +#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp" +#endif namespace oneapi { namespace math { @@ -103,6 +106,15 @@ class philox4x32x10 { : pimpl_(rocrand::create_philox4x32x10(selector.get_queue(), seed)) {} #endif +#ifdef ONEMATH_ENABLE_ARMPL_BACKEND + philox4x32x10(backend_selector selector, std::uint64_t seed = default_seed) + : pimpl_(armpl::create_philox4x32x10(selector.get_queue(), seed)) {} + + philox4x32x10(backend_selector selector, + std::initializer_list seed) + : pimpl_(armpl::create_philox4x32x10(selector.get_queue(), seed)) {} +#endif + philox4x32x10(const philox4x32x10& other) { pimpl_.reset(other.pimpl_.get()->copy_state()); } @@ -192,6 +204,14 @@ class mrg32k3a { : pimpl_(rocrand::create_mrg32k3a(selector.get_queue(), seed)) {} #endif +#ifdef ONEMATH_ENABLE_ARMPL_BACKEND + mrg32k3a(backend_selector selector, std::uint32_t seed = default_seed) + : pimpl_(armpl::create_mrg32k3a(selector.get_queue(), seed)) {} + + mrg32k3a(backend_selector selector, std::initializer_list seed) + : pimpl_(armpl::create_mrg32k3a(selector.get_queue(), seed)) {} +#endif + mrg32k3a(const mrg32k3a& other) { pimpl_.reset(other.pimpl_.get()->copy_state()); } diff --git a/src/rng/backends/CMakeLists.txt b/src/rng/backends/CMakeLists.txt index 52ddcdd3c..40399138a 100644 --- a/src/rng/backends/CMakeLists.txt +++ b/src/rng/backends/CMakeLists.txt @@ -36,3 +36,8 @@ if(ENABLE_ROCRAND_BACKEND AND UNIX) add_subdirectory(rocrand) endif() +if(ENABLE_ARMPL_BACKEND AND UNIX) + add_subdirectory(armpl) +endif() + + diff --git a/src/rng/backends/armpl/CMakeLists.txt b/src/rng/backends/armpl/CMakeLists.txt new file mode 100644 index 000000000..a3280ddc9 --- /dev/null +++ b/src/rng/backends/armpl/CMakeLists.txt @@ -0,0 +1,76 @@ +#=============================================================================== +# Copyright 2025 SiPearl +# Copyright 2020-2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemath_rng_armpl) +set(LIB_OBJ ${LIB_NAME}_obj) + +find_package(ARMPL REQUIRED) + +set(SOURCES armpl_common.hpp + philox4x32x10.cpp + mrg32k3a.cpp + $<$: armpl_rng_cpu_wrappers.cpp> +) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemath_backend_libs_rng ${LIB_NAME}) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ARMPL_INCLUDE} + ${ONEMATH_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMATH_BUILD_COPT} ${ARMPL_COPT}) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) +endif() + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMATH::SYCL::SYCL ${ARMPL_LINK}) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +# Set oneMATH libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMATH::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMathTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMathTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/rng/backends/armpl/armpl_common.hpp b/src/rng/backends/armpl/armpl_common.hpp new file mode 100644 index 000000000..c3028f702 --- /dev/null +++ b/src/rng/backends/armpl/armpl_common.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2025 SiPearl +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _RNG_CPU_COMMON_HPP_ +#define _RNG_CPU_COMMON_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#define GET_MULTI_PTR template get_multi_ptr().get_raw() +#define __fp16 _Float16 +#define INTEGER64 1 + +#include "armpl.h" + +namespace oneapi { +namespace math { +namespace rng { +namespace armpl { + +inline int check_armpl_version(armpl_int_t major_req, armpl_int_t minor_req, armpl_int_t build_req, + const char* message) { + armpl_int_t major, minor, build; + char* tag; + armplversion(&major, &minor, &build, (const char**)&tag); + if (major > major_req) { + return 0; + } + else if (major == major_req && minor > minor_req) { + return 0; + } + else if (major == major_req && minor == minor_req && build >= build_req) { + return 0; + } + throw oneapi::math::unimplemented("rng", "version support", message); +} + +template +static inline auto host_task_internal(H& cgh, F f, int) -> decltype(cgh.host_task(f)) { + return cgh.host_task(f); +} + +template +static inline void host_task_internal(H& cgh, F f, long) { +#ifndef __SYCL_DEVICE_ONLY__ + cgh.template single_task(f); +#endif +} + +template +static inline void host_task(H& cgh, F f) { + (void)host_task_internal(cgh, f, 0); +} + +template +class kernel_name {}; + +template +class kernel_name_usm {}; + +} // namespace armpl +} // namespace rng +} // namespace math +} // namespace oneapi + +#endif //_RNG_CPU_COMMON_HPP_ diff --git a/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp b/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp new file mode 100644 index 000000000..8ecd5333b --- /dev/null +++ b/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp @@ -0,0 +1,30 @@ +/******************************************************************************* +* Copyright 2025 SiPearl +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "rng/function_table.hpp" +#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp" + +#define WRAPPER_VERSION 1 + +extern "C" ONEMATH_EXPORT rng_function_table_t onemath_rng_table = { + WRAPPER_VERSION, oneapi::math::rng::armpl::create_philox4x32x10, + oneapi::math::rng::armpl::create_philox4x32x10, oneapi::math::rng::armpl::create_mrg32k3a, + oneapi::math::rng::armpl::create_mrg32k3a +}; diff --git a/src/rng/backends/armpl/mrg32k3a.cpp b/src/rng/backends/armpl/mrg32k3a.cpp new file mode 100644 index 000000000..43702752a --- /dev/null +++ b/src/rng/backends/armpl/mrg32k3a.cpp @@ -0,0 +1,498 @@ +/******************************************************************************* +* Copyright 2025 SiPearl +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/math/exceptions.hpp" +#include "oneapi/math/rng/detail/engine_impl.hpp" +#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp" + +#include "armpl_common.hpp" + +namespace oneapi { +namespace math { +namespace rng { +namespace armpl { + +class mrg32k3a_impl : public oneapi::math::rng::detail::engine_impl { +public: + mrg32k3a_impl(sycl::queue queue, std::uint32_t seed) + : oneapi::math::rng::detail::engine_impl(queue) { + vslNewStream(&stream_, VSL_BRNG_MRG32K3A, seed); + state_size_ = vslGetStreamSize(stream_); + } + + mrg32k3a_impl(sycl::queue queue, std::initializer_list seed) + : oneapi::math::rng::detail::engine_impl(queue) { + vslNewStreamEx(&stream_, VSL_BRNG_MRG32K3A, 2 * seed.size(), + reinterpret_cast(seed.begin())); + state_size_ = vslGetStreamSize(stream_); + } + + mrg32k3a_impl(const mrg32k3a_impl* other) : oneapi::math::rng::detail::engine_impl(*other) { + vslCopyStream(&stream_, other->stream_); + state_size_ = vslGetStreamSize(stream_); + } + + // Buffers APIs + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + if (distr.a() < 0) + check_armpl_version(25, 04, 0, + "ArmPl : Uniform int32 generation is not functional with <0 bound"); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const gaussian& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const lognormal& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, std::int64_t n, + sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, std::int64_t n, + sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const bernoulli& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.p()); + }); + }); + } + + virtual void generate(const bernoulli& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + std::uint32_t* r_ptr = acc_r.GET_MULTI_PTR; + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + reinterpret_cast(r_ptr), distr.p()); + }); + }); + } + + virtual void generate(const poisson& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual void generate(const poisson& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual void generate(const bits& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR); + }); + }); + } + + // USM APIs + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, std::int32_t* r, + const std::vector& dependencies) override { + if (distr.a() < 0) + check_armpl_version(25, 04, 0, + "ArmPl : Uniform int32 generation is not functional with <0 bound"); + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(), + distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(), + distr.b()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const bernoulli& distr, + std::int64_t n, std::int32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n, r, distr.p()); + }); + }); + } + + virtual sycl::event generate(const bernoulli& distr, + std::int64_t n, std::uint32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n, + reinterpret_cast(r), distr.p()); + }); + }); + } + + virtual sycl::event generate( + const poisson& distr, std::int64_t n, + std::int32_t* r, const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual sycl::event generate( + const poisson& distr, std::int64_t n, + std::uint32_t* r, const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>( + cgh, [=]() { viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, stream, n, r); }); + }); + } + + virtual oneapi::math::rng::detail::engine_impl* copy_state() override { + return new mrg32k3a_impl(this); + } + + virtual void skip_ahead(std::uint64_t num_to_skip) override { + vslSkipAheadStream(stream_, num_to_skip); + } + + virtual void skip_ahead(std::initializer_list num_to_skip) override { + vslSkipAheadStreamEx(stream_, num_to_skip.size(), (unsigned long long*)num_to_skip.begin()); + } + + virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { + vslLeapfrogStream(stream_, idx, stride); + } + + virtual ~mrg32k3a_impl() override { + vslDeleteStream(&stream_); + } + +private: + VSLStreamStatePtr stream_; + std::int32_t state_size_; +}; + +oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue, std::uint32_t seed) { + return new mrg32k3a_impl(queue, seed); +} + +oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue, + std::initializer_list seed) { + return new mrg32k3a_impl(queue, seed); +} + +} // namespace armpl +} // namespace rng +} // namespace math +} // namespace oneapi diff --git a/src/rng/backends/armpl/philox4x32x10.cpp b/src/rng/backends/armpl/philox4x32x10.cpp new file mode 100644 index 000000000..59192c41b --- /dev/null +++ b/src/rng/backends/armpl/philox4x32x10.cpp @@ -0,0 +1,501 @@ +/******************************************************************************* +* Copyright 2025 SiPearl +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/math/exceptions.hpp" +#include "oneapi/math/rng/detail/engine_impl.hpp" +#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp" + +#include "armpl_common.hpp" + +namespace oneapi { +namespace math { +namespace rng { +namespace armpl { + +class philox4x32x10_impl : public oneapi::math::rng::detail::engine_impl { +public: + philox4x32x10_impl(sycl::queue queue, std::uint64_t seed) + : oneapi::math::rng::detail::engine_impl(queue) { + vslNewStreamEx(&stream_, VSL_BRNG_PHILOX4X32X10, 2, + reinterpret_cast(&seed)); + state_size_ = vslGetStreamSize(stream_); + } + + philox4x32x10_impl(sycl::queue queue, std::initializer_list seed) + : oneapi::math::rng::detail::engine_impl(queue) { + vslNewStreamEx(&stream_, VSL_BRNG_PHILOX4X32X10, 2 * seed.size(), + reinterpret_cast(seed.begin())); + state_size_ = vslGetStreamSize(stream_); + } + + philox4x32x10_impl(const philox4x32x10_impl* other) + : oneapi::math::rng::detail::engine_impl(*other) { + vslCopyStream(&stream_, other->stream_); + state_size_ = vslGetStreamSize(stream_); + } + + // Buffers APIs + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, + std::int64_t n, sycl::buffer& r) override { + if (distr.a() < 0) + check_armpl_version(25, 04, 0, + "ArmPl : Uniform int32 generation is not functional with <0 bound"); + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const uniform& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.a(), distr.b()); + }); + }); + } + + virtual void generate(const gaussian& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const gaussian& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev()); + }); + }); + } + + virtual void generate(const lognormal& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, std::int64_t n, + sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const lognormal& distr, std::int64_t n, + sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual void generate(const bernoulli& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR, distr.p()); + }); + }); + } + + virtual void generate(const bernoulli& distr, + std::int64_t n, sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + std::uint32_t* r_ptr = acc_r.GET_MULTI_PTR; + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, + static_cast(acc_stream.GET_MULTI_PTR), n, + reinterpret_cast(r_ptr), distr.p()); + }); + }); + } + + virtual void generate(const poisson& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual void generate(const poisson& distr, + std::int64_t n, sycl::buffer& r) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual void generate(const bits& distr, std::int64_t n, + sycl::buffer& r) override { + sycl::buffer stream_buf(static_cast(stream_), state_size_); + queue_.submit([&](sycl::handler& cgh) { + auto acc_stream = stream_buf.get_access(cgh); + auto acc_r = r.get_access(cgh); + host_task>(cgh, [=]() { + viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, + static_cast(acc_stream.GET_MULTI_PTR), n, + acc_r.GET_MULTI_PTR); + }); + }); + } + + // USM APIs + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, std::int32_t* r, + const std::vector& dependencies) override { + if (distr.a() < 0) + check_armpl_version(25, 04, 0, + "ArmPl : Uniform int32 generation is not functional with <0 bound"); + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(), + distr.b()); + }); + }); + } + + virtual sycl::event generate(const uniform& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(), + distr.b()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const gaussian& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(), + distr.stddev()); + }); + }); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, float* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const lognormal& distr, + std::int64_t n, double* r, + const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method lognormal", + "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const bernoulli& distr, + std::int64_t n, std::int32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n, r, distr.p()); + }); + }); + } + + virtual sycl::event generate(const bernoulli& distr, + std::int64_t n, std::uint32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>(cgh, [=]() { + viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n, + reinterpret_cast(r), distr.p()); + }); + }); + } + + virtual sycl::event generate( + const poisson& distr, std::int64_t n, + std::int32_t* r, const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual sycl::event generate( + const poisson& distr, std::int64_t n, + std::uint32_t* r, const std::vector& dependencies) override { + throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL"); + } + + virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, + const std::vector& dependencies) override { + sycl::event::wait_and_throw(dependencies); + return queue_.submit([&](sycl::handler& cgh) { + VSLStreamStatePtr stream = stream_; + host_task>( + cgh, [=]() { viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, stream, n, r); }); + }); + } + + virtual oneapi::math::rng::detail::engine_impl* copy_state() override { + return new philox4x32x10_impl(this); + } + + virtual void skip_ahead(std::uint64_t num_to_skip) override { + vslSkipAheadStream(stream_, num_to_skip); + } + + virtual void skip_ahead(std::initializer_list num_to_skip) override { + vslSkipAheadStreamEx(stream_, num_to_skip.size(), (unsigned long long*)num_to_skip.begin()); + } + + virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { + vslLeapfrogStream(stream_, idx, stride); + } + + virtual ~philox4x32x10_impl() override { + vslDeleteStream(&stream_); + } + +private: + VSLStreamStatePtr stream_; + std::int32_t state_size_; +}; + +oneapi::math::rng::detail::engine_impl* create_philox4x32x10(sycl::queue queue, + std::uint64_t seed) { + return new philox4x32x10_impl(queue, seed); +} + +oneapi::math::rng::detail::engine_impl* create_philox4x32x10( + sycl::queue queue, std::initializer_list seed) { + return new philox4x32x10_impl(queue, seed); +} + +} // namespace armpl +} // namespace rng +} // namespace math +} // namespace oneapi diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index a9df19787..c444316b7 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -182,6 +182,11 @@ foreach(domain ${TEST_TARGET_DOMAINS}) list(APPEND ONEMATH_LIBRARIES_${domain} onemath_${domain}_rocrand) endif() + if(domain STREQUAL "rng" AND ENABLE_ARMPL_BACKEND) + add_dependencies(test_main_${domain}_ct onemath_${domain}_armpl) + list(APPEND ONEMATH_LIBRARIES_${domain} onemath_${domain}_armpl) + endif() + if(domain STREQUAL "dft" AND ENABLE_CUFFT_BACKEND) add_dependencies(test_main_${domain}_ct onemath_${domain}_cufft) list(APPEND ONEMATH_LIBRARIES_${domain} onemath_${domain}_cufft)