Skip to content

Commit 8517a20

Browse files
authored
[RNG] Added random engine adaptor: count_engine_adaptor (#644)
1 parent 09ede6e commit 8517a20

File tree

11 files changed

+343
-21
lines changed

11 files changed

+343
-21
lines changed

include/oneapi/math/rng/device/detail/bits_impl.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class distribution_base<oneapi::math::rng::device::bits<UIntType>> {
2929
protected:
3030
template <typename EngineType>
3131
auto generate(EngineType& engine) -> typename std::enable_if<
32-
!std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
32+
!std::is_same<EngineType, mcg59<EngineType::vec_size>>::value &&
33+
!std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
3334
typename std::conditional<EngineType::vec_size == 1, UIntType,
3435
sycl::vec<UIntType, EngineType::vec_size>>::type>::type {
3536
static_assert(std::is_same<UIntType, uint32_t>::value,
@@ -39,7 +40,8 @@ class distribution_base<oneapi::math::rng::device::bits<UIntType>> {
3940

4041
template <typename EngineType>
4142
auto generate(EngineType& engine) -> typename std::enable_if<
42-
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
43+
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value ||
44+
std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
4345
typename std::conditional<EngineType::vec_size == 1, UIntType,
4446
sycl::vec<UIntType, EngineType::vec_size>>::type>::type {
4547
static_assert(std::is_same<UIntType, uint64_t>::value,
@@ -48,17 +50,21 @@ class distribution_base<oneapi::math::rng::device::bits<UIntType>> {
4850
}
4951

5052
template <typename EngineType>
51-
typename std::enable_if<!std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
52-
UIntType>::type
53+
typename std::enable_if<
54+
!std::is_same<EngineType, mcg59<EngineType::vec_size>>::value &&
55+
!std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
56+
UIntType>::type
5357
generate_single(EngineType& engine) {
5458
static_assert(std::is_same<UIntType, uint32_t>::value,
5559
"oneMath: bits works only with std::uint32_t");
5660
return engine.generate_single();
5761
}
5862

5963
template <typename EngineType>
60-
typename std::enable_if<std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
61-
UIntType>::type
64+
typename std::enable_if<
65+
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value ||
66+
std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
67+
UIntType>::type
6268
generate_single(EngineType& engine) {
6369
static_assert(std::is_same<UIntType, uint64_t>::value,
6470
"oneMath: bits for mcg59 works only with std::uint64_t");

include/oneapi/math/rng/device/detail/mcg31m1_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ namespace oneapi::math::rng::device {
2525
template <std::int32_t VecSize = 1>
2626
class mcg31m1;
2727

28+
template <typename Engine>
29+
class count_engine_adaptor;
30+
2831
namespace detail {
2932

3033
template <std::uint64_t VecSize>
@@ -224,6 +227,8 @@ class engine_base<oneapi::math::rng::device::mcg31m1<VecSize>> {
224227
}
225228

226229
engine_state<oneapi::math::rng::device::mcg31m1<VecSize>> state_;
230+
friend class oneapi::math::rng::device::count_engine_adaptor<
231+
oneapi::math::rng::device::mcg31m1<VecSize>>;
227232
};
228233

229234
} // namespace detail

include/oneapi/math/rng/device/detail/mcg59_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ namespace oneapi::math::rng::device {
2525
template <std::int32_t VecSize = 1>
2626
class mcg59;
2727

28+
template <typename Engine>
29+
class count_engine_adaptor;
30+
2831
namespace detail {
2932

3033
template <std::uint32_t VecSize>
@@ -267,6 +270,8 @@ class engine_base<oneapi::math::rng::device::mcg59<VecSize>> {
267270
}
268271

269272
engine_state<oneapi::math::rng::device::mcg59<VecSize>> state_;
273+
friend class oneapi::math::rng::device::count_engine_adaptor<
274+
oneapi::math::rng::device::mcg59<VecSize>>;
270275
};
271276

272277
} // namespace detail

include/oneapi/math/rng/device/detail/mrg32k3a_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ namespace oneapi::math::rng::device {
3232
template <std::int32_t VecSize = 1>
3333
class mrg32k3a;
3434

35+
template <typename Engine>
36+
class count_engine_adaptor;
37+
3538
namespace detail {
3639

3740
template <std::int32_t VecSize>
@@ -377,6 +380,8 @@ class engine_base<oneapi::math::rng::device::mrg32k3a<VecSize>> {
377380
}
378381

379382
engine_state<oneapi::math::rng::device::mrg32k3a<VecSize>> state_;
383+
friend class oneapi::math::rng::device::count_engine_adaptor<
384+
oneapi::math::rng::device::mrg32k3a<VecSize>>;
380385
};
381386

382387
} // namespace detail

include/oneapi/math/rng/device/detail/philox4x32x10_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace oneapi::math::rng::device {
2727
template <std::int32_t VecSize = 1>
2828
class philox4x32x10;
2929

30+
template <typename Engine>
31+
class count_engine_adaptor;
32+
3033
namespace detail {
3134

3235
template <std::int32_t VecSize>
@@ -546,6 +549,8 @@ class engine_base<oneapi::math::rng::device::philox4x32x10<VecSize>> {
546549
}
547550

548551
engine_state<oneapi::math::rng::device::philox4x32x10<VecSize>> state_;
552+
friend class oneapi::math::rng::device::count_engine_adaptor<
553+
oneapi::math::rng::device::philox4x32x10<VecSize>>;
549554
};
550555

551556
} // namespace detail

include/oneapi/math/rng/device/detail/uniform_bits_impl.hpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,25 @@ class distribution_base<oneapi::math::rng::device::uniform_bits<UIntType>> {
3131
auto generate(EngineType& engine) ->
3232
typename std::conditional<EngineType::vec_size == 1, UIntType,
3333
sycl::vec<UIntType, EngineType::vec_size>>::type {
34-
static_assert(std::is_same<EngineType, philox4x32x10<EngineType::vec_size>>::value ||
35-
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
36-
"oneMath: uniform_bits works only with philox4x32x10/mcg59 engines");
34+
static_assert(
35+
std::is_same<EngineType, philox4x32x10<EngineType::vec_size>>::value ||
36+
std::is_same<EngineType,
37+
count_engine_adaptor<philox4x32x10<EngineType::vec_size>>>::value ||
38+
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value ||
39+
std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
40+
"oneMath: uniform_bits works only with philox4x32x10/mcg59 engines and their adaptors");
3741
return engine.template generate_uniform_bits<UIntType>();
3842
}
3943

4044
template <typename EngineType>
4145
UIntType generate_single(EngineType& engine) {
42-
static_assert(std::is_same<EngineType, philox4x32x10<EngineType::vec_size>>::value ||
43-
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value,
44-
"oneMath: uniform_bits works only with philox4x32x10/mcg59 engines");
46+
static_assert(
47+
std::is_same<EngineType, philox4x32x10<EngineType::vec_size>>::value ||
48+
std::is_same<EngineType,
49+
count_engine_adaptor<philox4x32x10<EngineType::vec_size>>>::value ||
50+
std::is_same<EngineType, mcg59<EngineType::vec_size>>::value ||
51+
std::is_same<EngineType, count_engine_adaptor<mcg59<EngineType::vec_size>>>::value,
52+
"oneMath: uniform_bits works only with philox4x32x10/mcg59 engines and their adaptors");
4553
return engine.template generate_single_uniform_bits<UIntType>();
4654
}
4755
};

include/oneapi/math/rng/device/engines.hpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace oneapi::math::rng::device {
3838
// skip_ahead
3939
//
4040
template <std::int32_t VecSize>
41-
class philox4x32x10 : detail::engine_base<philox4x32x10<VecSize>> {
41+
class philox4x32x10 : public detail::engine_base<philox4x32x10<VecSize>> {
4242
public:
4343
static constexpr std::uint64_t default_seed = 0;
4444

@@ -79,7 +79,7 @@ class philox4x32x10 : detail::engine_base<philox4x32x10<VecSize>> {
7979
// skip_ahead
8080
//
8181
template <std::int32_t VecSize>
82-
class mrg32k3a : detail::engine_base<mrg32k3a<VecSize>> {
82+
class mrg32k3a : public detail::engine_base<mrg32k3a<VecSize>> {
8383
public:
8484
static constexpr std::uint32_t default_seed = 1;
8585

@@ -119,7 +119,7 @@ class mrg32k3a : detail::engine_base<mrg32k3a<VecSize>> {
119119
// skip_ahead
120120
//
121121
template <std::int32_t VecSize>
122-
class mcg31m1 : detail::engine_base<mcg31m1<VecSize>> {
122+
class mcg31m1 : public detail::engine_base<mcg31m1<VecSize>> {
123123
public:
124124
static constexpr std::uint32_t default_seed = 1;
125125

@@ -146,7 +146,7 @@ class mcg31m1 : detail::engine_base<mcg31m1<VecSize>> {
146146
// skip_ahead
147147
//
148148
template <std::int32_t VecSize>
149-
class mcg59 : detail::engine_base<mcg59<VecSize>> {
149+
class mcg59 : public detail::engine_base<mcg59<VecSize>> {
150150
public:
151151
static constexpr std::uint32_t default_seed = 1;
152152

@@ -165,6 +165,83 @@ class mcg59 : detail::engine_base<mcg59<VecSize>> {
165165
friend class detail::distribution_base;
166166
};
167167

168+
// ENGINE ADAPTORS
169+
170+
// Class oneapi::math::rng::device::count_engine_adaptor
171+
template <typename Engine>
172+
class count_engine_adaptor {
173+
public:
174+
static constexpr std::int32_t vec_size = Engine::vec_size;
175+
176+
// ctors
177+
template <typename... Params>
178+
count_engine_adaptor(Params... params) : engine_(params...) {}
179+
180+
count_engine_adaptor(const Engine& engine) : engine_(engine) {}
181+
count_engine_adaptor(Engine&& engine) : engine_(std::move(engine)) {}
182+
183+
// methods
184+
template <typename RealType>
185+
auto generate(RealType a, RealType b) {
186+
counted_ += Engine::vec_size;
187+
return engine_.generate(a, b);
188+
}
189+
190+
auto generate() {
191+
counted_ += Engine::vec_size;
192+
return engine_.generate();
193+
}
194+
195+
template <typename RealType>
196+
RealType generate_single(RealType a, RealType b) {
197+
counted_++;
198+
return engine_.generate_single(a, b);
199+
}
200+
201+
template <typename UIntType>
202+
auto generate_uniform_bits() {
203+
if constexpr (std::is_same<UIntType, std::uint32_t>::value) {
204+
counted_ += Engine::vec_size;
205+
}
206+
else {
207+
counted_ += 2 * Engine::vec_size;
208+
}
209+
return engine_.template generate_uniform_bits<UIntType>();
210+
}
211+
212+
template <typename UIntType>
213+
auto generate_single_uniform_bits() {
214+
if constexpr (std::is_same<UIntType, std::uint32_t>::value) {
215+
counted_ += 1;
216+
}
217+
else {
218+
counted_ += 2;
219+
}
220+
return engine_.template generate_single_uniform_bits<UIntType>();
221+
}
222+
223+
auto generate_bits() {
224+
counted_ += Engine::vec_size;
225+
return engine_.generate_bits();
226+
}
227+
228+
// getters
229+
std::int64_t get_count() const {
230+
return counted_;
231+
}
232+
233+
const Engine& base() const {
234+
return engine_;
235+
}
236+
237+
private:
238+
Engine engine_;
239+
std::int64_t counted_ = 0;
240+
241+
template <typename DistrType>
242+
friend class detail::distribution_base;
243+
};
244+
168245
} // namespace oneapi::math::rng::device
169246

170247
#endif // ONEMATH_RNG_DEVICE_ENGINES_HPP_

0 commit comments

Comments
 (0)