Skip to content

Commit aca714a

Browse files
committed
Add discrete distribution
1 parent 3878ad7 commit aca714a

File tree

5 files changed

+175
-16
lines changed

5 files changed

+175
-16
lines changed

doc/modules/random.rst

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,46 @@ Distributions
6969
=============
7070

7171
Distributions produce numerical values from the random bitstream provided by
72-
an RNG engine.
72+
an RNG engine. For efficiency, each distribution subroutine accepts an *array*
73+
of values that are filled with samples of the distribution.
7374

7475
normal_distribution
7576
-------------------
7677

77-
Fill the given array with Gaussian-distributed numbers generated using the
78-
given random number engine about the given mean value with the given standard
79-
deviation.
78+
Each element of the sampled array is distributed according to a Gaussian
79+
function with the given mean and standard deviation.
8080

8181
uniform_int_distribution
8282
------------------------
8383

84-
Fill the given array with uniformly distributed integers generated using the
85-
given random number engine, between the two bounds (inclusive on both sides).
84+
Each element is uniformly sampled between the two provided bounds, inclusive on
85+
both sides.
8686

8787
uniform_real_distribution
8888
-------------------------
8989

90-
Fill the given array with uniformly distributed real numbers generated using the
91-
given random number engine, between the two bounds (inclusive on left side only).
90+
Each element is a sample of a uniform distribution between the two bounds,
91+
inclusive on left side only.
9292

93+
discrete_distribution
94+
---------------------
95+
96+
The discrete distribution is constructed with an array of :math:`N` weights:
97+
the probability that an index in the range :math:`[1, N]` will be selected.
98+
::
99+
100+
integer(C_INT), dimension(4), parameter :: weights = [1, 1, 2, 4]
101+
integer(C_INT), dimension(1024) :: sampled
102+
call discrete_distribution(weights, Engine(), sampled)
103+
104+
In the above example, ``1`` and ``2`` will be present in the ``sampled`` array
105+
about the same number of times since those indices have equal weight; ``3``
106+
will be present about twice as often as ``1`` and ``4`` will be present about
107+
four times as often.
108+
109+
.. note:: The C++ distribution returns values in :math:`[0, N)`, so in
110+
accordance with Flibcpp's :ref:`indexing convention <conventions_indexing>`
111+
the result is transformed when provided to Fortran users.
93112

94113
.. ############################################################################
95114
.. end of doc/modules/random.rst

include/flc_random.i

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ static inline void flc_generate(D dist, G& g, T* data, size_t size) {
5454
%}
5555

5656
%apply (const SWIGTYPE *DATA, size_t SIZE) {
57-
(SWIGTYPE const *WEIGHTS, size_t WEIGHTSIZE) };
57+
(const int32_t *WEIGHTS, size_t WEIGHTSIZE),
58+
(const int64_t *WEIGHTS, size_t WEIGHTSIZE) };
5859

5960
%inline %{
6061
template<class T, class G>
@@ -77,6 +78,16 @@ static void normal_distribution(T mean, T stddev,
7778
flc_generate(std::normal_distribution<T>(mean, stddev),
7879
engine, DATA, DATASIZE);
7980
}
81+
82+
template<class T, class G>
83+
static void discrete_distribution(const T* WEIGHTS, size_t WEIGHTSIZE,
84+
G& engine, T* DATA, size_t DATASIZE) {
85+
std::discrete_distribution<T> dist(WEIGHTS, WEIGHTS + WEIGHTSIZE);
86+
T* const end = DATA + DATASIZE;
87+
while (DATA != end) {
88+
*DATA++ = dist(engine) + 1; // Note: transform to Fortran 1-offset
89+
}
90+
}
8091
%}
8192

8293
%define %flc_distribution(NAME, STDENGINE, TYPE)
@@ -93,3 +104,7 @@ static void normal_distribution(T mean, T stddev,
93104
%flc_distribution(uniform_real, FLC_DEFAULT_ENGINE, double)
94105

95106
%flc_distribution(normal, FLC_DEFAULT_ENGINE, double)
107+
108+
// Discrete sampling distribution
109+
%flc_distribution(discrete, FLC_DEFAULT_ENGINE, int32_t)
110+
%flc_distribution(discrete, FLC_DEFAULT_ENGINE, int64_t)

src/flc_random.f90

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ module flc_random
6262
module procedure swigf_uniform_int_distribution__SWIG_1, swigf_uniform_int_distribution__SWIG_2
6363
end interface
6464
public :: uniform_int_distribution
65+
interface discrete_distribution
66+
module procedure swigf_discrete_distribution__SWIG_1, swigf_discrete_distribution__SWIG_2
67+
end interface
68+
public :: discrete_distribution
6569

6670
! WRAPPER DECLARATIONS
6771
interface
@@ -223,6 +227,26 @@ subroutine swigc_normal_distribution(farg1, farg2, farg3, farg4) &
223227
type(SwigArrayWrapper) :: farg4
224228
end subroutine
225229

230+
subroutine swigc_discrete_distribution__SWIG_1(farg1, farg3, farg4) &
231+
bind(C, name="_wrap_discrete_distribution__SWIG_1")
232+
use, intrinsic :: ISO_C_BINDING
233+
import :: swigarraywrapper
234+
import :: swigclasswrapper
235+
type(SwigArrayWrapper) :: farg1
236+
type(SwigClassWrapper) :: farg3
237+
type(SwigArrayWrapper) :: farg4
238+
end subroutine
239+
240+
subroutine swigc_discrete_distribution__SWIG_2(farg1, farg3, farg4) &
241+
bind(C, name="_wrap_discrete_distribution__SWIG_2")
242+
use, intrinsic :: ISO_C_BINDING
243+
import :: swigarraywrapper
244+
import :: swigclasswrapper
245+
type(SwigArrayWrapper) :: farg1
246+
type(SwigClassWrapper) :: farg3
247+
type(SwigArrayWrapper) :: farg4
248+
end subroutine
249+
226250
end interface
227251

228252

@@ -522,5 +546,35 @@ subroutine normal_distribution(mean, stddev, engine, data)
522546
call swigc_normal_distribution(farg1, farg2, farg3, farg4)
523547
end subroutine
524548

549+
subroutine swigf_discrete_distribution__SWIG_1(weights, engine, data)
550+
use, intrinsic :: ISO_C_BINDING
551+
integer(C_INT32_T), dimension(:), intent(in), target :: weights
552+
class(MersenneEngine4), intent(in) :: engine
553+
integer(C_INT32_T), dimension(:), target :: data
554+
type(SwigArrayWrapper) :: farg1
555+
type(SwigClassWrapper) :: farg3
556+
type(SwigArrayWrapper) :: farg4
557+
558+
call SWIGTM_fin_int32_t_Sb__SB_(weights, farg1)
559+
farg3 = engine%swigdata
560+
call SWIGTM_fin_int32_t_Sb__SB_(data, farg4)
561+
call swigc_discrete_distribution__SWIG_1(farg1, farg3, farg4)
562+
end subroutine
563+
564+
subroutine swigf_discrete_distribution__SWIG_2(weights, engine, data)
565+
use, intrinsic :: ISO_C_BINDING
566+
integer(C_INT64_T), dimension(:), intent(in), target :: weights
567+
class(MersenneEngine4), intent(in) :: engine
568+
integer(C_INT64_T), dimension(:), target :: data
569+
type(SwigArrayWrapper) :: farg1
570+
type(SwigClassWrapper) :: farg3
571+
type(SwigArrayWrapper) :: farg4
572+
573+
call SWIGTM_fin_int64_t_Sb__SB_(weights, farg1)
574+
farg3 = engine%swigdata
575+
call SWIGTM_fin_int64_t_Sb__SB_(data, farg4)
576+
call swigc_discrete_distribution__SWIG_2(farg1, farg3, farg4)
577+
end subroutine
578+
525579

526580
end module

src/flc_randomFORTRAN_wrap.cxx

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,16 @@ static void normal_distribution(T mean, T stddev,
286286
engine, DATA, DATASIZE);
287287
}
288288

289+
template<class T, class G>
290+
static void discrete_distribution(const T* WEIGHTS, size_t WEIGHTSIZE,
291+
G& engine, T* DATA, size_t DATASIZE) {
292+
std::discrete_distribution<T> dist(WEIGHTS, WEIGHTS + WEIGHTSIZE);
293+
T* const end = DATA + DATASIZE;
294+
while (DATA != end) {
295+
*DATA++ = dist(engine) + 1; // Note: transform to Fortran 1-offset
296+
}
297+
}
298+
289299

290300
struct SwigClassWrapper {
291301
void* cptr;
@@ -671,5 +681,41 @@ SWIGEXPORT void _wrap_normal_distribution(double const *farg1, double const *far
671681
}
672682

673683

684+
SWIGEXPORT void _wrap_discrete_distribution__SWIG_1(SwigArrayWrapper *farg1, SwigClassWrapper *farg3, SwigArrayWrapper *farg4) {
685+
int32_t *arg1 = (int32_t *) 0 ;
686+
size_t arg2 ;
687+
std::mt19937 *arg3 = 0 ;
688+
int32_t *arg4 = (int32_t *) 0 ;
689+
size_t arg5 ;
690+
691+
arg1 = (int32_t *)farg1->data;
692+
arg2 = farg1->size;
693+
SWIG_check_nonnull(*farg3, "std::mt19937 &", "MersenneEngine4", "discrete_distribution< int32_t,std::mt19937 >(int32_t const *,size_t,std::mt19937 &,int32_t *,size_t)", return );
694+
arg3 = (std::mt19937 *)farg3->cptr;
695+
arg4 = (int32_t *)farg4->data;
696+
arg5 = farg4->size;
697+
discrete_distribution< int32_t,std::mt19937 >((int32_t const *)arg1,arg2,*arg3,arg4,arg5);
698+
SWIG_free_rvalue< std::mt19937, SWIGPOLICY_std_mt19937 >(*farg3);
699+
}
700+
701+
702+
SWIGEXPORT void _wrap_discrete_distribution__SWIG_2(SwigArrayWrapper *farg1, SwigClassWrapper *farg3, SwigArrayWrapper *farg4) {
703+
int64_t *arg1 = (int64_t *) 0 ;
704+
size_t arg2 ;
705+
std::mt19937 *arg3 = 0 ;
706+
int64_t *arg4 = (int64_t *) 0 ;
707+
size_t arg5 ;
708+
709+
arg1 = (int64_t *)farg1->data;
710+
arg2 = farg1->size;
711+
SWIG_check_nonnull(*farg3, "std::mt19937 &", "MersenneEngine4", "discrete_distribution< int64_t,std::mt19937 >(int64_t const *,size_t,std::mt19937 &,int64_t *,size_t)", return );
712+
arg3 = (std::mt19937 *)farg3->cptr;
713+
arg4 = (int64_t *)farg4->data;
714+
arg5 = farg4->size;
715+
discrete_distribution< int64_t,std::mt19937 >((int64_t const *)arg1,arg2,*arg3,arg4,arg5);
716+
SWIG_free_rvalue< std::mt19937, SWIGPOLICY_std_mt19937 >(*farg3);
717+
}
718+
719+
674720
} // extern
675721

test/test_random.F90

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ program test_random
1313
call test_uniform_int_distribution()
1414
call test_uniform_real_distribution()
1515
call test_normal_distribution()
16+
call test_discrete_distribution()
1617
contains
1718

1819
!-----------------------------------------------------------------------------!
@@ -57,8 +58,7 @@ subroutine test_uniform_int_distribution()
5758
call uniform_int_distribution(5, 15, rng, arr)
5859
ASSERT(minval(arr) >= 5)
5960
ASSERT(maxval(arr) <= 15)
60-
61-
write(*,*) sum(arr) - (10 * size(arr))
61+
ASSERT(abs(sum(arr) - (10 * size(arr))) < size(arr) / 10)
6262

6363
call rng%release()
6464
end subroutine
@@ -79,8 +79,8 @@ subroutine test_uniform_real_distribution()
7979
ASSERT(maxval(arr) <= 15.d0)
8080

8181
avg = sum(arr) / real(size(arr), kind=8)
82-
write(*,*) "Average of sampled real values:", avg
8382
ASSERT(avg >= 9.5 .and. avg <= 10.5)
83+
8484
call rng%release()
8585
end subroutine
8686

@@ -90,18 +90,43 @@ subroutine test_normal_distribution()
9090
use flc_random, only : Engine => MersenneEngine4, normal_distribution
9191
implicit none
9292
real(C_DOUBLE), dimension(:), allocatable :: arr
93+
real(C_DOUBLE) :: avg
9394
type(Engine) :: rng
9495

95-
allocate(arr(10))
96+
allocate(arr(1024))
9697
rng = Engine() ! Initialize with default seed
9798

98-
! Mean=10, sigma=5
99-
call normal_distribution(10.0d0, 5.0d0, rng, arr)
100-
write(*,*) "Samples from normal distribution:", arr
99+
! Mean=10, sigma=2
100+
call normal_distribution(10.0d0, 2.0d0, rng, arr)
101+
102+
avg = sum(arr) / real(size(arr), kind=8)
103+
ASSERT(avg >= 9.9 .and. avg <= 10.1)
101104

102105
call rng%release()
103106
end subroutine
104107

108+
!-----------------------------------------------------------------------------!
109+
subroutine test_discrete_distribution()
110+
use, intrinsic :: ISO_C_BINDING
111+
use flc_random, only : Engine => MersenneEngine4, discrete_distribution
112+
implicit none
113+
integer(C_INT), dimension(4), parameter :: weights = [1, 1, 2, 4]
114+
integer(C_INT), dimension(1024) :: sampled
115+
integer(C_INT), dimension(4) :: tallied = 0
116+
integer(C_INT), dimension(4), parameter :: gold_result = [130, 127, 267, 500]
117+
integer :: i
118+
119+
! Sample 1024 random ints
120+
call discrete_distribution(weights, Engine(), sampled)
121+
ASSERT(minval(sampled) == 1)
122+
ASSERT(maxval(sampled) == size(weights))
123+
do i = 1, size(sampled)
124+
tallied(sampled(i)) = tallied(sampled(i)) + 1
125+
enddo
126+
127+
ASSERT(all(tallied == gold_result))
128+
end subroutine
129+
105130
!-----------------------------------------------------------------------------!
106131

107132
end program

0 commit comments

Comments
 (0)