Skip to content

Commit 0260ff7

Browse files
authored
[RNG][Device API] Added (u)int8_t and (u)int16_t types support for Uniform (uxlfoundation#632)
1 parent c255b1b commit 0260ff7

File tree

4 files changed

+190
-72
lines changed

4 files changed

+190
-72
lines changed

include/oneapi/math/rng/device/detail/uniform_impl.hpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,15 @@ class distribution_base<oneapi::math::rng::device::uniform<Type, Method>> {
117117
sycl::vec<Type, EngineType::vec_size>>::type {
118118
using OutType = typename std::conditional<EngineType::vec_size == 1, Type,
119119
sycl::vec<Type, EngineType::vec_size>>::type;
120-
using FpType =
121-
typename std::conditional<std::is_same<Method, uniform_method::accurate>::value, double,
122-
float>::type;
120+
using FpType = typename std::conditional<
121+
!std::is_same_v<Method, uniform_method::accurate> ||
122+
std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
123+
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t>,
124+
float, double>::type;
123125
OutType res;
124126
if constexpr (std::is_integral<Type>::value) {
125-
if constexpr (std::is_same_v<Type, std::int32_t> ||
126-
std::is_same_v<Type, std::uint32_t>) {
127+
if constexpr (!std::is_same_v<Type, std::int64_t> &&
128+
!std::is_same_v<Type, std::uint64_t>) {
127129
return generate_single_int<FpType, OutType>(engine);
128130
}
129131
else {
@@ -238,13 +240,15 @@ class distribution_base<oneapi::math::rng::device::uniform<Type, Method>> {
238240

239241
template <typename EngineType>
240242
Type generate_single(EngineType& engine) {
241-
using FpType =
242-
typename std::conditional<std::is_same<Method, uniform_method::accurate>::value, double,
243-
float>::type;
243+
using FpType = typename std::conditional<
244+
!std::is_same_v<Method, uniform_method::accurate> ||
245+
std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
246+
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t>,
247+
float, double>::type;
244248
Type res;
245249
if constexpr (std::is_integral<Type>::value) {
246-
if constexpr (std::is_same_v<Type, std::int32_t> ||
247-
std::is_same_v<Type, std::uint32_t>) {
250+
if constexpr (!std::is_same_v<Type, std::int64_t> &&
251+
!std::is_same_v<Type, std::uint64_t>) {
248252
FpType res_fp =
249253
engine.generate_single(static_cast<FpType>(a_), static_cast<FpType>(b_));
250254
res_fp = sycl::floor(res_fp);

include/oneapi/math/rng/device/distributions.hpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,24 @@ namespace oneapi::math::rng::device {
3636
// Supported types:
3737
// float
3838
// double
39+
// std::int8_t
40+
// std::uint8_t
41+
// std::int16_t
42+
// std::uint16_t
3943
// std::int32_t
4044
// std::uint32_t
45+
// std::int64_t
46+
// std::uint64_t
4147
//
4248
// Supported methods:
4349
// oneapi::math::rng::device::uniform_method::standard
4450
// oneapi::math::rng::device::uniform_method::accurate
4551
//
4652
// Input arguments:
4753
// a - left bound. 0.0 by default
48-
// b - right bound. 1.0 by default (for std::(u)int32_t std::numeric_limits<std::int32_t>::max()
49-
// is used for accurate method and 2^23 is used for standard method)
54+
// b - right bound. 1.0 for floating point types.
55+
// For integer types, 2^23 is used for std::(u)int32_t in the case of standard method,
56+
// and std::numeric_limits<integer_type>::max() for other integer types.
5057
//
5158
// Note: using (un)signed integer uniform distribution with uniform_method::standard method may
5259
// cause incorrect statistics of the produced random numbers (due to rounding error) if
@@ -61,6 +68,8 @@ class uniform : detail::distribution_base<uniform<Type, Method>> {
6168
"oneMath: rng/uniform: method is incorrect");
6269

6370
static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value ||
71+
std::is_same_v<Type, std::int8_t> || std::is_same_v<Type, std::uint8_t> ||
72+
std::is_same_v<Type, std::int16_t> || std::is_same_v<Type, std::uint16_t> ||
6473
std::is_same<Type, std::int32_t>::value ||
6574
std::is_same<Type, std::uint32_t>::value ||
6675
std::is_same<Type, std::int64_t>::value ||
@@ -77,6 +86,11 @@ class uniform : detail::distribution_base<uniform<Type, Method>> {
7786
std::is_integral<Type>::value
7887
? ((std::is_same_v<Type, std::uint64_t> || std::is_same_v<Type, std::int64_t>)
7988
? (std::numeric_limits<Type>::max)()
89+
: (std::is_same_v<Type, std::int8_t> ||
90+
std::is_same_v<Type, std::uint8_t> ||
91+
std::is_same_v<Type, std::int16_t> ||
92+
std::is_same_v<Type, std::uint16_t>)
93+
? (std::numeric_limits<Type>::max)()
8094
: (std::is_same<Method, uniform_method::standard>::value
8195
? (1 << 23)
8296
: (std::numeric_limits<Type>::max)()))
@@ -579,8 +593,12 @@ class poisson : detail::distribution_base<poisson<IntType, Method>> {
579593
// Represents discrete Bernoulli random number distribution
580594
//
581595
// Supported types:
582-
// std::uint32_t
596+
// std::int8_t
597+
// std::uint8_t
598+
// std::int16_t
599+
// std::uint16_t
583600
// std::int32_t
601+
// std::uint32_t
584602
//
585603
// Supported methods:
586604
// oneapi::math::rng::bernoulli_method::icdf;

tests/unit_tests/rng/device/include/rng_device_test_common.hpp

Lines changed: 69 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,29 @@ bool compare_moments(const std::vector<Fp, AllocType>& r, double tM, double tD,
166166
return true;
167167
}
168168

169+
template <typename Distribution, typename Fp, typename AllocType>
170+
bool calculate_and_compare_moments_uniform(Distribution distr,
171+
const std::vector<Fp, AllocType>& r) {
172+
double tM, tD, tQ;
173+
double a = distr.a();
174+
double b = distr.b();
175+
176+
// Theoretical moments
177+
if constexpr (std::is_integral<Fp>::value) {
178+
tM = (a + b - 1.0) / 2.0;
179+
tD = ((b - a) * (b - a) - 1.0) / 12.0;
180+
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
181+
(7.0 / 240.0);
182+
}
183+
else {
184+
tM = (b + a) / 2.0;
185+
tD = ((b - a) * (b - a)) / 12.0;
186+
tQ = ((b - a) * (b - a) * (b - a) * (b - a)) / 80.0;
187+
}
188+
189+
return compare_moments(r, tM, tD, tQ);
190+
}
191+
169192
template <typename Distribution>
170193
struct statistics_device {};
171194

@@ -174,92 +197,79 @@ struct statistics_device<oneapi::math::rng::device::uniform<Fp, Method>> {
174197
template <typename AllocType>
175198
bool check(const std::vector<Fp, AllocType>& r,
176199
const oneapi::math::rng::device::uniform<Fp, Method>& distr) {
177-
double tM, tD, tQ;
178-
Fp a = distr.a();
179-
Fp b = distr.b();
200+
return calculate_and_compare_moments_uniform(distr, r);
201+
}
202+
};
180203

181-
// Theoretical moments
182-
tM = (b + a) / 2.0;
183-
tD = ((b - a) * (b - a)) / 12.0;
184-
tQ = ((b - a) * (b - a) * (b - a) * (b - a)) / 80.0;
204+
template <typename Method>
205+
struct statistics_device<oneapi::math::rng::device::uniform<std::int8_t, Method>> {
206+
template <typename AllocType>
207+
bool check(const std::vector<std::int8_t, AllocType>& r,
208+
const oneapi::math::rng::device::uniform<std::int8_t, Method>& distr) {
209+
return calculate_and_compare_moments_uniform(distr, r);
210+
}
211+
};
185212

186-
return compare_moments(r, tM, tD, tQ);
213+
template <typename Method>
214+
struct statistics_device<oneapi::math::rng::device::uniform<std::uint8_t, Method>> {
215+
template <typename AllocType>
216+
bool check(const std::vector<std::uint8_t, AllocType>& r,
217+
const oneapi::math::rng::device::uniform<std::uint8_t, Method>& distr) {
218+
return calculate_and_compare_moments_uniform(distr, r);
187219
}
188220
};
189221

190222
template <typename Method>
191-
struct statistics_device<oneapi::math::rng::device::uniform<std::int32_t, Method>> {
223+
struct statistics_device<oneapi::math::rng::device::uniform<std::int16_t, Method>> {
192224
template <typename AllocType>
193-
bool check(const std::vector<std::int32_t, AllocType>& r,
194-
const oneapi::math::rng::device::uniform<std::int32_t, Method>& distr) {
195-
double tM, tD, tQ;
196-
double a = distr.a();
197-
double b = distr.b();
225+
bool check(const std::vector<std::int16_t, AllocType>& r,
226+
const oneapi::math::rng::device::uniform<std::int16_t, Method>& distr) {
227+
return calculate_and_compare_moments_uniform(distr, r);
228+
}
229+
};
198230

199-
// Theoretical moments
200-
tM = (a + b - 1.0) / 2.0;
201-
tD = ((b - a) * (b - a) - 1.0) / 12.0;
202-
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
203-
(7.0 / 240.0);
231+
template <typename Method>
232+
struct statistics_device<oneapi::math::rng::device::uniform<std::uint16_t, Method>> {
233+
template <typename AllocType>
234+
bool check(const std::vector<std::uint16_t, AllocType>& r,
235+
const oneapi::math::rng::device::uniform<std::uint16_t, Method>& distr) {
236+
return calculate_and_compare_moments_uniform(distr, r);
237+
}
238+
};
204239

205-
return compare_moments(r, tM, tD, tQ);
240+
template <typename Method>
241+
struct statistics_device<oneapi::math::rng::device::uniform<std::int32_t, Method>> {
242+
template <typename AllocType>
243+
bool check(const std::vector<int32_t, AllocType>& r,
244+
const oneapi::math::rng::device::uniform<int32_t, Method>& distr) {
245+
return calculate_and_compare_moments_uniform(distr, r);
206246
}
207247
};
208248

209249
template <typename Method>
210250
struct statistics_device<oneapi::math::rng::device::uniform<std::uint32_t, Method>> {
211251
template <typename AllocType>
212-
bool check(const std::vector<std::uint32_t, AllocType>& r,
213-
const oneapi::math::rng::device::uniform<std::uint32_t, Method>& distr) {
214-
double tM, tD, tQ;
215-
double a = distr.a();
216-
double b = distr.b();
217-
218-
// Theoretical moments
219-
tM = (a + b - 1.0) / 2.0;
220-
tD = ((b - a) * (b - a) - 1.0) / 12.0;
221-
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
222-
(7.0 / 240.0);
223-
224-
return compare_moments(r, tM, tD, tQ);
252+
bool check(const std::vector<uint32_t, AllocType>& r,
253+
const oneapi::math::rng::device::uniform<uint32_t, Method>& distr) {
254+
return calculate_and_compare_moments_uniform(distr, r);
225255
}
226256
};
227257

228258
template <typename Method>
229259
struct statistics_device<oneapi::math::rng::device::uniform<std::int64_t, Method>> {
230260
template <typename AllocType>
231-
bool check(const std::vector<std::int64_t, AllocType>& r,
232-
const oneapi::math::rng::device::uniform<std::int64_t, Method>& distr) {
233-
double tM, tD, tQ;
234-
double a = distr.a();
235-
double b = distr.b();
236-
237-
// Theoretical moments
238-
tM = (a + b - 1.0) / 2.0;
239-
tD = ((b - a) * (b - a) - 1.0) / 12.0;
240-
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
241-
(7.0 / 240.0);
242-
243-
return compare_moments(r, tM, tD, tQ);
261+
bool check(const std::vector<int64_t, AllocType>& r,
262+
const oneapi::math::rng::device::uniform<int64_t, Method>& distr) {
263+
return calculate_and_compare_moments_uniform(distr, r);
244264
}
245265
};
246266

247267
template <typename Method>
248268
struct statistics_device<oneapi::math::rng::device::uniform<std::uint64_t, Method>> {
249269
template <typename AllocType>
250-
bool check(const std::vector<std::uint64_t, AllocType>& r,
251-
const oneapi::math::rng::device::uniform<std::uint64_t, Method>& distr) {
252-
double tM, tD, tQ;
253-
double a = distr.a();
254-
double b = distr.b();
255-
256-
// Theoretical moments
257-
tM = (a + b - 1.0) / 2.0;
258-
tD = ((b - a) * (b - a) - 1.0) / 12.0;
259-
tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) +
260-
(7.0 / 240.0);
261-
262-
return compare_moments(r, tM, tD, tQ);
270+
bool check(const std::vector<uint64_t, AllocType>& r,
271+
const oneapi::math::rng::device::uniform<uint64_t, Method>& distr) {
272+
return calculate_and_compare_moments_uniform(distr, r);
263273
}
264274
};
265275

tests/unit_tests/rng/device/moments/moments.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,92 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, RealDoublePrecision) {
6767
EXPECT_TRUEORSKIP((test3(GetParam())));
6868
}
6969

70+
/* Test small types (u)int8, (u)int16 only with uniform_method::standard since numbers are always generated
71+
as single precision numbers */
72+
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer8Precision) {
73+
rng_device_test<
74+
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
75+
oneapi::math::rng::device::uniform<
76+
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
77+
test1;
78+
EXPECT_TRUEORSKIP((test1(GetParam())));
79+
rng_device_test<
80+
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
81+
oneapi::math::rng::device::uniform<
82+
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
83+
test2;
84+
EXPECT_TRUEORSKIP((test2(GetParam())));
85+
rng_device_test<
86+
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
87+
oneapi::math::rng::device::uniform<
88+
std::int8_t, oneapi::math::rng::device::uniform_method::standard>>>
89+
test3;
90+
EXPECT_TRUEORSKIP((test3(GetParam())));
91+
}
92+
93+
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger8Precision) {
94+
rng_device_test<
95+
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
96+
oneapi::math::rng::device::uniform<
97+
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
98+
test1;
99+
EXPECT_TRUEORSKIP((test1(GetParam())));
100+
rng_device_test<
101+
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
102+
oneapi::math::rng::device::uniform<
103+
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
104+
test2;
105+
EXPECT_TRUEORSKIP((test2(GetParam())));
106+
rng_device_test<
107+
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
108+
oneapi::math::rng::device::uniform<
109+
std::uint8_t, oneapi::math::rng::device::uniform_method::standard>>>
110+
test3;
111+
EXPECT_TRUEORSKIP((test3(GetParam())));
112+
}
113+
114+
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer16Precision) {
115+
rng_device_test<
116+
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
117+
oneapi::math::rng::device::uniform<
118+
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
119+
test1;
120+
EXPECT_TRUEORSKIP((test1(GetParam())));
121+
rng_device_test<
122+
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
123+
oneapi::math::rng::device::uniform<
124+
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
125+
test2;
126+
EXPECT_TRUEORSKIP((test2(GetParam())));
127+
rng_device_test<
128+
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
129+
oneapi::math::rng::device::uniform<
130+
std::int16_t, oneapi::math::rng::device::uniform_method::standard>>>
131+
test3;
132+
EXPECT_TRUEORSKIP((test3(GetParam())));
133+
}
134+
135+
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger16Precision) {
136+
rng_device_test<
137+
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
138+
oneapi::math::rng::device::uniform<
139+
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
140+
test1;
141+
EXPECT_TRUEORSKIP((test1(GetParam())));
142+
rng_device_test<
143+
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
144+
oneapi::math::rng::device::uniform<
145+
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
146+
test2;
147+
EXPECT_TRUEORSKIP((test2(GetParam())));
148+
rng_device_test<
149+
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
150+
oneapi::math::rng::device::uniform<
151+
std::uint16_t, oneapi::math::rng::device::uniform_method::standard>>>
152+
test3;
153+
EXPECT_TRUEORSKIP((test3(GetParam())));
154+
}
155+
70156
TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, IntegerPrecision) {
71157
rng_device_test<
72158
moments_test<oneapi::math::rng::device::philox4x32x10<1>,

0 commit comments

Comments
 (0)