diff --git a/CMakeLists.txt b/CMakeLists.txt
index 705448b83..6cf7feaa8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -106,7 +106,8 @@ endif()
if(ENABLE_MKLCPU_BACKEND
OR ENABLE_MKLGPU_BACKEND
OR ENABLE_CURAND_BACKEND
- OR ENABLE_ROCRAND_BACKEND)
+ OR ENABLE_ROCRAND_BACKEND
+ OR ENABLE_ARMPL_BACKEND)
list(APPEND DOMAINS_LIST "rng")
endif()
if(ENABLE_MKLGPU_BACKEND
diff --git a/README.md b/README.md
index 528e8081e..e755373be 100644
--- a/README.md
+++ b/README.md
@@ -276,12 +276,18 @@ Supported compilers include:
Dynamic, Static |
- | RNG |
+ RNG |
x86 CPU |
Intel(R) oneMKL |
Intel DPC++AdaptiveCpp |
Dynamic, Static |
+
+ | aarch64 CPU |
+ Arm Performance Libraries |
+ Open DPC++AdaptiveCpp |
+ Dynamic, Static |
+
| Intel GPU |
Intel(R) oneMKL |
diff --git a/include/oneapi/math/detail/backends_table.hpp b/include/oneapi/math/detail/backends_table.hpp
index 48e5f98c6..5dd435cf1 100644
--- a/include/oneapi/math/detail/backends_table.hpp
+++ b/include/oneapi/math/detail/backends_table.hpp
@@ -175,6 +175,12 @@ static std::map>> libraries =
{
#ifdef ONEMATH_ENABLE_MKLCPU_BACKEND
LIB_NAME("rng_mklcpu")
+#endif
+ } },
+ { device::aarch64cpu,
+ {
+#ifdef ONEMATH_ENABLE_ARMPL_BACKEND
+ LIB_NAME("rng_armpl")
#endif
} },
{ device::intelgpu,
diff --git a/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp b/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp
new file mode 100644
index 000000000..298c0fee3
--- /dev/null
+++ b/include/oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2025 SiPearl
+* Copyright 2020-2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMATH_RNG_ARMPL_HPP_
+#define _ONEMATH_RNG_ARMPL_HPP_
+
+#include
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include "oneapi/math/detail/export.hpp"
+#include "oneapi/math/rng/detail/engine_impl.hpp"
+
+namespace oneapi {
+namespace math {
+namespace rng {
+namespace armpl {
+
+ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_philox4x32x10(sycl::queue queue,
+ std::uint64_t seed);
+
+ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_philox4x32x10(
+ sycl::queue queue, std::initializer_list seed);
+
+ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue,
+ std::uint32_t seed);
+
+ONEMATH_EXPORT oneapi::math::rng::detail::engine_impl* create_mrg32k3a(
+ sycl::queue queue, std::initializer_list seed);
+
+} // namespace armpl
+} // namespace rng
+} // namespace math
+} // namespace oneapi
+
+#endif //_ONEMATH_RNG_ARMPL_HPP_
diff --git a/include/oneapi/math/rng/engines.hpp b/include/oneapi/math/rng/engines.hpp
index f09243459..3ae489ff1 100644
--- a/include/oneapi/math/rng/engines.hpp
+++ b/include/oneapi/math/rng/engines.hpp
@@ -47,6 +47,9 @@
#ifdef ONEMATH_ENABLE_ROCRAND_BACKEND
#include "oneapi/math/rng/detail/rocrand/onemath_rng_rocrand.hpp"
#endif
+#ifdef ONEMATH_ENABLE_ARMPL_BACKEND
+#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp"
+#endif
namespace oneapi {
namespace math {
@@ -103,6 +106,15 @@ class philox4x32x10 {
: pimpl_(rocrand::create_philox4x32x10(selector.get_queue(), seed)) {}
#endif
+#ifdef ONEMATH_ENABLE_ARMPL_BACKEND
+ philox4x32x10(backend_selector selector, std::uint64_t seed = default_seed)
+ : pimpl_(armpl::create_philox4x32x10(selector.get_queue(), seed)) {}
+
+ philox4x32x10(backend_selector selector,
+ std::initializer_list seed)
+ : pimpl_(armpl::create_philox4x32x10(selector.get_queue(), seed)) {}
+#endif
+
philox4x32x10(const philox4x32x10& other) {
pimpl_.reset(other.pimpl_.get()->copy_state());
}
@@ -192,6 +204,14 @@ class mrg32k3a {
: pimpl_(rocrand::create_mrg32k3a(selector.get_queue(), seed)) {}
#endif
+#ifdef ONEMATH_ENABLE_ARMPL_BACKEND
+ mrg32k3a(backend_selector selector, std::uint32_t seed = default_seed)
+ : pimpl_(armpl::create_mrg32k3a(selector.get_queue(), seed)) {}
+
+ mrg32k3a(backend_selector selector, std::initializer_list seed)
+ : pimpl_(armpl::create_mrg32k3a(selector.get_queue(), seed)) {}
+#endif
+
mrg32k3a(const mrg32k3a& other) {
pimpl_.reset(other.pimpl_.get()->copy_state());
}
diff --git a/src/rng/backends/CMakeLists.txt b/src/rng/backends/CMakeLists.txt
index 52ddcdd3c..40399138a 100644
--- a/src/rng/backends/CMakeLists.txt
+++ b/src/rng/backends/CMakeLists.txt
@@ -36,3 +36,8 @@ if(ENABLE_ROCRAND_BACKEND AND UNIX)
add_subdirectory(rocrand)
endif()
+if(ENABLE_ARMPL_BACKEND AND UNIX)
+ add_subdirectory(armpl)
+endif()
+
+
diff --git a/src/rng/backends/armpl/CMakeLists.txt b/src/rng/backends/armpl/CMakeLists.txt
new file mode 100644
index 000000000..a3280ddc9
--- /dev/null
+++ b/src/rng/backends/armpl/CMakeLists.txt
@@ -0,0 +1,76 @@
+#===============================================================================
+# Copyright 2025 SiPearl
+# Copyright 2020-2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+#
+#
+# SPDX-License-Identifier: Apache-2.0
+#===============================================================================
+
+set(LIB_NAME onemath_rng_armpl)
+set(LIB_OBJ ${LIB_NAME}_obj)
+
+find_package(ARMPL REQUIRED)
+
+set(SOURCES armpl_common.hpp
+ philox4x32x10.cpp
+ mrg32k3a.cpp
+ $<$: armpl_rng_cpu_wrappers.cpp>
+)
+
+add_library(${LIB_NAME})
+add_library(${LIB_OBJ} OBJECT ${SOURCES})
+add_dependencies(onemath_backend_libs_rng ${LIB_NAME})
+target_include_directories(${LIB_OBJ}
+ PRIVATE ${PROJECT_SOURCE_DIR}/include
+ ${PROJECT_SOURCE_DIR}/src
+ ${CMAKE_BINARY_DIR}/bin
+ ${ARMPL_INCLUDE}
+ ${ONEMATH_GENERATED_INCLUDE_PATH}
+)
+
+target_compile_options(${LIB_OBJ} PRIVATE ${ONEMATH_BUILD_COPT} ${ARMPL_COPT})
+if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
+ add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES})
+endif()
+
+target_link_libraries(${LIB_OBJ} PUBLIC ONEMATH::SYCL::SYCL ${ARMPL_LINK})
+
+set_target_properties(${LIB_OBJ} PROPERTIES
+ POSITION_INDEPENDENT_CODE ON
+)
+target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ})
+
+# Set oneMATH libraries as not transitive for dynamic
+if(BUILD_SHARED_LIBS)
+ set_target_properties(${LIB_NAME} PROPERTIES
+ INTERFACE_LINK_LIBRARIES ONEMATH::SYCL::SYCL
+ )
+endif()
+
+# Add major version to the library
+set_target_properties(${LIB_NAME} PROPERTIES
+ SOVERSION ${PROJECT_VERSION_MAJOR}
+)
+
+# Add dependencies rpath to the library
+list(APPEND CMAKE_BUILD_RPATH $)
+
+# Add the library to install package
+install(TARGETS ${LIB_OBJ} EXPORT oneMathTargets)
+install(TARGETS ${LIB_NAME} EXPORT oneMathTargets
+ RUNTIME DESTINATION bin
+ ARCHIVE DESTINATION lib
+ LIBRARY DESTINATION lib
+)
diff --git a/src/rng/backends/armpl/armpl_common.hpp b/src/rng/backends/armpl/armpl_common.hpp
new file mode 100644
index 000000000..c3028f702
--- /dev/null
+++ b/src/rng/backends/armpl/armpl_common.hpp
@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2025 SiPearl
+* Copyright 2020-2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _RNG_CPU_COMMON_HPP_
+#define _RNG_CPU_COMMON_HPP_
+
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#define GET_MULTI_PTR template get_multi_ptr().get_raw()
+#define __fp16 _Float16
+#define INTEGER64 1
+
+#include "armpl.h"
+
+namespace oneapi {
+namespace math {
+namespace rng {
+namespace armpl {
+
+// Returns 0 when the ArmPL version reported by armplversion() is at least
+// major_req.minor_req.build_req; otherwise throws oneapi::math::unimplemented carrying `message`.
+inline int check_armpl_version(armpl_int_t major_req, armpl_int_t minor_req, armpl_int_t build_req,
+                               const char* message) {
+    armpl_int_t major = 0, minor = 0, build = 0;
+    const char* tag = nullptr; // out-parameter required by armplversion(); typed so no cast is needed
+    armplversion(&major, &minor, &build, &tag);
+    // Lexicographic comparison of (major, minor, build) against the required triple.
+    const bool new_enough =
+        major > major_req ||
+        (major == major_req && (minor > minor_req || (minor == minor_req && build >= build_req)));
+    if (new_enough) {
+        return 0;
+    }
+    throw oneapi::math::unimplemented("rng", "version support", message);
+}
+
+template // overload set below: use cgh.host_task(f) when that call is well-formed, else single_task
+static inline auto host_task_internal(H& cgh, F f, int) -> decltype(cgh.host_task(f)) {
+ return cgh.host_task(f); // this overload drops out of resolution if the decltype above is ill-formed
+}
+
+template
+static inline void host_task_internal(H& cgh, F f, long) {
+#ifndef __SYCL_DEVICE_ONLY__
+ cgh.template single_task(f); // fallback path; compiled out when __SYCL_DEVICE_ONLY__ is defined
+#endif
+}
+
+template
+static inline void host_task(H& cgh, F f) {
+ (void)host_task_internal(cgh, f, 0); // 0 is an int, so the `int` overload wins whenever it is viable
+}
+
+template
+class kernel_name {};
+
+template
+class kernel_name_usm {};
+
+} // namespace armpl
+} // namespace rng
+} // namespace math
+} // namespace oneapi
+
+#endif //_RNG_CPU_COMMON_HPP_
diff --git a/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp b/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp
new file mode 100644
index 000000000..8ecd5333b
--- /dev/null
+++ b/src/rng/backends/armpl/armpl_rng_cpu_wrappers.cpp
@@ -0,0 +1,30 @@
+/*******************************************************************************
+* Copyright 2025 SiPearl
+* Copyright 2020-2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include "rng/function_table.hpp"
+#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp"
+
+#define WRAPPER_VERSION 1
+
+extern "C" ONEMATH_EXPORT rng_function_table_t onemath_rng_table = { // exported with C linkage
+ WRAPPER_VERSION, oneapi::math::rng::armpl::create_philox4x32x10, // table version, then the creators
+ oneapi::math::rng::armpl::create_philox4x32x10, oneapi::math::rng::armpl::create_mrg32k3a,
+ oneapi::math::rng::armpl::create_mrg32k3a // NOTE(review): each name appears twice, presumably one per seed overload; confirm against rng_function_table_t
+};
diff --git a/src/rng/backends/armpl/mrg32k3a.cpp b/src/rng/backends/armpl/mrg32k3a.cpp
new file mode 100644
index 000000000..43702752a
--- /dev/null
+++ b/src/rng/backends/armpl/mrg32k3a.cpp
@@ -0,0 +1,498 @@
+/*******************************************************************************
+* Copyright 2025 SiPearl
+* Copyright 2020-2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include "oneapi/math/exceptions.hpp"
+#include "oneapi/math/rng/detail/engine_impl.hpp"
+#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp"
+
+#include "armpl_common.hpp"
+
+namespace oneapi {
+namespace math {
+namespace rng {
+namespace armpl {
+
+class mrg32k3a_impl : public oneapi::math::rng::detail::engine_impl {
+public:
+ mrg32k3a_impl(sycl::queue queue, std::uint32_t seed)
+ : oneapi::math::rng::detail::engine_impl(queue) {
+ vslNewStream(&stream_, VSL_BRNG_MRG32K3A, seed); // single-seed stream creation; returned status is not checked
+ state_size_ = vslGetStreamSize(stream_); // later used as the extent of the sycl::buffer wrapped around stream_
+ }
+
+    mrg32k3a_impl(sycl::queue queue, std::initializer_list seed) // NOTE(review): assumes 32-bit list elements, as in the oneMath mrg32k3a API; confirm
+            : oneapi::math::rng::detail::engine_impl(queue) {
+        // The count is the number of 32-bit seed words, i.e. seed.size(); doubling it reads past the list.
+        vslNewStreamEx(&stream_, VSL_BRNG_MRG32K3A, seed.size(), reinterpret_cast(seed.begin()));
+        state_size_ = vslGetStreamSize(stream_); // later used as the extent of the buffer around stream_
+    }
+
+ mrg32k3a_impl(const mrg32k3a_impl* other) : oneapi::math::rng::detail::engine_impl(*other) {
+ vslCopyStream(&stream_, other->stream_); // stream_ receives its own handle produced from other's stream
+ state_size_ = vslGetStreamSize(stream_); // re-queried from the new stream rather than copied from other
+ }
+
+ // Buffers APIs
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ if (distr.a() < 0)
+ check_armpl_version(25, 04, 0,
+ "ArmPl : Uniform int32 generation is not functional with <0 bound");
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ viRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev());
+ });
+ });
+ }
+
+ virtual void generate(const lognormal& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const lognormal& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const lognormal& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const lognormal& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const bernoulli& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.p());
+ });
+ });
+ }
+
+ virtual void generate(const bernoulli& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ std::uint32_t* r_ptr = acc_r.GET_MULTI_PTR;
+ viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ reinterpret_cast(r_ptr), distr.p());
+ });
+ });
+ }
+
+ virtual void generate(const poisson& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const poisson& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL");
+ }
+
+ virtual void generate(const bits& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR);
+ });
+ });
+ }
+
+ // USM APIs
+
+ virtual sycl::event generate(const uniform& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const uniform& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const uniform& distr,
+ std::int64_t n, std::int32_t* r,
+ const std::vector& dependencies) override {
+ if (distr.a() < 0)
+ check_armpl_version(25, 04, 0,
+ "ArmPl : Uniform int32 generation is not functional with <0 bound");
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream, n, r, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const uniform& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(),
+ distr.b());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const uniform& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, stream, n, r, distr.a(),
+ distr.b());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const gaussian& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(),
+ distr.stddev());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const gaussian& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, stream, n, r, distr.mean(),
+ distr.stddev());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const gaussian& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(),
+ distr.stddev());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const gaussian& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, stream, n, r, distr.mean(),
+ distr.stddev());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const lognormal& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(const lognormal& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(const lognormal& distr,
+ std::int64_t n, float* r,
+ const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(const lognormal& distr,
+ std::int64_t n, double* r,
+ const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method lognormal",
+ "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(const bernoulli& distr,
+ std::int64_t n, std::int32_t* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n, r, distr.p());
+ });
+ });
+ }
+
+ virtual sycl::event generate(const bernoulli& distr,
+ std::int64_t n, std::uint32_t* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(cgh, [=]() {
+ viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, n,
+ reinterpret_cast(r), distr.p());
+ });
+ });
+ }
+
+ virtual sycl::event generate(
+ const poisson& distr, std::int64_t n,
+ std::int32_t* r, const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(
+ const poisson& distr, std::int64_t n,
+ std::uint32_t* r, const std::vector& dependencies) override {
+ throw oneapi::math::unimplemented("rng", "method poisson", "not yet implemented in ArmPL");
+ }
+
+ virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r,
+ const std::vector& dependencies) override {
+ sycl::event::wait_and_throw(dependencies);
+ return queue_.submit([&](sycl::handler& cgh) {
+ VSLStreamStatePtr stream = stream_;
+ host_task>(
+ cgh, [=]() { viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, stream, n, r); });
+ });
+ }
+
+ virtual oneapi::math::rng::detail::engine_impl* copy_state() override {
+ return new mrg32k3a_impl(this); // raw pointer from new; nothing in this class frees it, so the caller must
+ }
+
+ virtual void skip_ahead(std::uint64_t num_to_skip) override {
+ vslSkipAheadStream(stream_, num_to_skip); // forwards to the VSL call on stream_; its return status is discarded
+ }
+
+ virtual void skip_ahead(std::initializer_list num_to_skip) override {
+ vslSkipAheadStreamEx(stream_, num_to_skip.size(), (unsigned long long*)num_to_skip.begin()); // C-style cast also drops the const of begin(); NOTE(review): prefer named casts
+ }
+
+ virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override {
+ vslLeapfrogStream(stream_, idx, stride); // forwards idx/stride unchanged; the return status is discarded
+ }
+
+ virtual ~mrg32k3a_impl() override {
+ vslDeleteStream(&stream_); // releases the VSL stream that each constructor creates into stream_
+ }
+
+private:
+ VSLStreamStatePtr stream_;
+ std::int32_t state_size_;
+};
+
+oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue, std::uint32_t seed) {
+ return new mrg32k3a_impl(queue, seed); // heap-allocated engine; the raw pointer is handed to the caller to own
+}
+
+oneapi::math::rng::detail::engine_impl* create_mrg32k3a(sycl::queue queue,
+ std::initializer_list seed) {
+ return new mrg32k3a_impl(queue, seed); // seed-list variant; the raw pointer is handed to the caller to own
+}
+
+} // namespace armpl
+} // namespace rng
+} // namespace math
+} // namespace oneapi
diff --git a/src/rng/backends/armpl/philox4x32x10.cpp b/src/rng/backends/armpl/philox4x32x10.cpp
new file mode 100644
index 000000000..59192c41b
--- /dev/null
+++ b/src/rng/backends/armpl/philox4x32x10.cpp
@@ -0,0 +1,501 @@
+/*******************************************************************************
+* Copyright 2025 SiPearl
+* Copyright 2020-2021 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include
+#if __has_include()
+#include
+#else
+#include
+#endif
+
+#include "oneapi/math/exceptions.hpp"
+#include "oneapi/math/rng/detail/engine_impl.hpp"
+#include "oneapi/math/rng/detail/armpl/onemath_rng_armpl.hpp"
+
+#include "armpl_common.hpp"
+
+namespace oneapi {
+namespace math {
+namespace rng {
+namespace armpl {
+
+class philox4x32x10_impl : public oneapi::math::rng::detail::engine_impl {
+public:
+ philox4x32x10_impl(sycl::queue queue, std::uint64_t seed)
+ : oneapi::math::rng::detail::engine_impl(queue) {
+ vslNewStreamEx(&stream_, VSL_BRNG_PHILOX4X32X10, 2,
+ reinterpret_cast(&seed));
+ state_size_ = vslGetStreamSize(stream_);
+ }
+
+ philox4x32x10_impl(sycl::queue queue, std::initializer_list seed)
+ : oneapi::math::rng::detail::engine_impl(queue) {
+ vslNewStreamEx(&stream_, VSL_BRNG_PHILOX4X32X10, 2 * seed.size(),
+ reinterpret_cast(seed.begin()));
+ state_size_ = vslGetStreamSize(stream_);
+ }
+
+ philox4x32x10_impl(const philox4x32x10_impl* other)
+ : oneapi::math::rng::detail::engine_impl(*other) {
+ vslCopyStream(&stream_, other->stream_);
+ state_size_ = vslGetStreamSize(stream_);
+ }
+
+ // Buffers APIs
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ if (distr.a() < 0)
+ check_armpl_version(25, 04, 0,
+ "ArmPl : Uniform int32 generation is not functional with <0 bound");
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ viRngUniform(VSL_RNG_METHOD_UNIFORM_STD,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const uniform& distr, std::int64_t n,
+ sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.a(), distr.b());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access(cgh);
+ host_task>(cgh, [=]() {
+ vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2,
+ static_cast(acc_stream.GET_MULTI_PTR), n,
+ acc_r.GET_MULTI_PTR, distr.mean(), distr.stddev());
+ });
+ });
+ }
+
+ virtual void generate(const gaussian& distr,
+ std::int64_t n, sycl::buffer& r) override {
+ sycl::buffer stream_buf(static_cast(stream_), state_size_);
+ queue_.submit([&](sycl::handler& cgh) {
+ auto acc_stream = stream_buf.get_access(cgh);
+ auto acc_r = r.get_access