@@ -38,7 +38,7 @@ namespace oneapi::math::rng::device {
3838// skip_ahead
3939//
4040template <std::int32_t VecSize>
41- class philox4x32x10 : detail::engine_base<philox4x32x10<VecSize>> {
41+ class philox4x32x10 : public detail ::engine_base<philox4x32x10<VecSize>> {
4242public:
4343 static constexpr std::uint64_t default_seed = 0 ;
4444
@@ -79,7 +79,7 @@ class philox4x32x10 : detail::engine_base<philox4x32x10<VecSize>> {
7979// skip_ahead
8080//
8181template <std::int32_t VecSize>
82- class mrg32k3a : detail::engine_base<mrg32k3a<VecSize>> {
82+ class mrg32k3a : public detail ::engine_base<mrg32k3a<VecSize>> {
8383public:
8484 static constexpr std::uint32_t default_seed = 1 ;
8585
@@ -119,7 +119,7 @@ class mrg32k3a : detail::engine_base<mrg32k3a<VecSize>> {
119119// skip_ahead
120120//
121121template <std::int32_t VecSize>
122- class mcg31m1 : detail::engine_base<mcg31m1<VecSize>> {
122+ class mcg31m1 : public detail ::engine_base<mcg31m1<VecSize>> {
123123public:
124124 static constexpr std::uint32_t default_seed = 1 ;
125125
@@ -146,7 +146,7 @@ class mcg31m1 : detail::engine_base<mcg31m1<VecSize>> {
146146// skip_ahead
147147//
148148template <std::int32_t VecSize>
149- class mcg59 : detail::engine_base<mcg59<VecSize>> {
149+ class mcg59 : public detail ::engine_base<mcg59<VecSize>> {
150150public:
151151 static constexpr std::uint32_t default_seed = 1 ;
152152
@@ -165,6 +165,83 @@ class mcg59 : detail::engine_base<mcg59<VecSize>> {
165165 friend class detail ::distribution_base;
166166};
167167
168+ // ENGINE ADAPTORS
169+
170+ // Class oneapi::math::rng::device::count_engine_adaptor
171+ template <typename Engine>
172+ class count_engine_adaptor {
173+ public:
174+ static constexpr std::int32_t vec_size = Engine::vec_size;
175+
176+ // ctors
177+ template <typename ... Params>
178+ count_engine_adaptor (Params... params) : engine_(params...) {}
179+
180+ count_engine_adaptor (const Engine& engine) : engine_(engine) {}
181+ count_engine_adaptor (Engine&& engine) : engine_(std::move(engine)) {}
182+
183+ // methods
184+ template <typename RealType>
185+ auto generate (RealType a, RealType b) {
186+ counted_ += Engine::vec_size;
187+ return engine_.generate (a, b);
188+ }
189+
190+ auto generate () {
191+ counted_ += Engine::vec_size;
192+ return engine_.generate ();
193+ }
194+
195+ template <typename RealType>
196+ RealType generate_single (RealType a, RealType b) {
197+ counted_++;
198+ return engine_.generate_single (a, b);
199+ }
200+
201+ template <typename UIntType>
202+ auto generate_uniform_bits () {
203+ if constexpr (std::is_same<UIntType, std::uint32_t >::value) {
204+ counted_ += Engine::vec_size;
205+ }
206+ else {
207+ counted_ += 2 * Engine::vec_size;
208+ }
209+ return engine_.template generate_uniform_bits <UIntType>();
210+ }
211+
212+ template <typename UIntType>
213+ auto generate_single_uniform_bits () {
214+ if constexpr (std::is_same<UIntType, std::uint32_t >::value) {
215+ counted_ += 1 ;
216+ }
217+ else {
218+ counted_ += 2 ;
219+ }
220+ return engine_.template generate_single_uniform_bits <UIntType>();
221+ }
222+
223+ auto generate_bits () {
224+ counted_ += Engine::vec_size;
225+ return engine_.generate_bits ();
226+ }
227+
228+ // getters
229+ std::int64_t get_count () const {
230+ return counted_;
231+ }
232+
233+ const Engine& base () const {
234+ return engine_;
235+ }
236+
237+ private:
238+ Engine engine_;
239+ std::int64_t counted_ = 0 ;
240+
241+ template <typename DistrType>
242+ friend class detail ::distribution_base;
243+ };
244+
168245} // namespace oneapi::math::rng::device
169246
170247#endif // ONEMATH_RNG_DEVICE_ENGINES_HPP_
0 commit comments