
Commit 3fcff92

Updated reduce_sum docs to reflect design-doc changes (design-doc pull request #17)
1 parent 9c3af80 commit 3fcff92

File tree: 2 files changed, +139 −80 lines


src/functions-reference/higher-order_functions.Rmd

Lines changed: 40 additions & 31 deletions
@@ -10,6 +10,7 @@ if (knitr::is_html_output()) {
 cat(' * <a href="functions-algebraic-solver.html">Algebraic Equation Solver</a>\n')
 cat(' * <a href="functions-ode-solver.html">Ordinary Differential Equation (ODE) Solvers</a>\n')
 cat(' * <a href="functions-1d-integrator.html">1D Integrator</a>\n')
+cat(' * <a href="functions-reduce.html">Reduce-Sum</a>\n')
 cat(' * <a href="functions-map.html">Higher-Order Map</a>\n')
 }
 ```
@@ -382,67 +383,75 @@ Internally the 1D integrator uses the double-exponential methods in the Boost 1D
 
 The gradients of the integral are computed in accordance with the Leibniz integral rule. Gradients of the integrand are computed internally with Stan's automatic differentiation.
 
-## Parallel Reduce-Sum Function {#functions-reduce}
+## Reduce-Sum Function {#functions-reduce}
 
-Stan provides a higher-order ```reduce_sum``` function for parallelizing operations that can be represented as a reduce (by summation) over a sequence of terms.
+Stan provides a higher-order `reduce_sum` function for parallelizing operations that can be represented as a sum of a function, `g: U -> real`, evaluated at each element of a list of type `U[]`, `{ x1, x2, ... }`. That is:
 
-### Reduce-Sum Function
+```g(x1) + g(x2) + ...```
 
-The reduce sum function operates on a reducing function, a list of
-sliced arguments (one for each term in the parallel-for), a recommended grainsize,
-and a set of shared arguments.
+`reduce_sum` does not operate on `g` itself, but takes a partial sum function, `f: U[] -> real`, where:
 
-<!-- real; map_rect; (F f, T[] sliced\_arg, int grainsize, T1 arg1, T2 arg2, ...); -->
-\index{{\tt \bfseries reduce\_sum }!{\tt (F f, T[] sliced\_arg, int grainsize, T1 arg1, T2 arg2, ...): real}|hyperpage}
+```f({ x1 }) = g(x1)```
+```f({ x1, x2 }) = g(x1) + g(x2)```
+```f({ x1, x2, ... }) = g(x1) + g(x2) + ...```
 
-`real` **`reduce_sum`**`(F f, T[] sliced_arg, int grainsize, T1 arg1, T2 arg2, ...)`<br>\newline
+### The Reduce-Sum Function
 
-Return the equivalent of `f(1, size(sliced_arg), arg1, arg2, ...)`, but compute
-the result in parallel by breaking `sliced_arg` into pieces and computing each piece
-in a different thread. `arg1, arg2, ...` are shared between all terms in the sum.
+The `reduce_sum` function takes a partial sum function, an array argument `x`
+(with one element for each term in the sum), a recommended grainsize, and a set of shared arguments, and
+parallelizes the resulting sum.
 
-* *`f`*: function literal referring to a function specifying the reduce operation with signature `(int, int, T[], T1, T2, ...):real`
+<!-- real; reduce_sum; (F f, T[] x, int grainsize, T1 s1, T2 s2, ...); -->
+\index{{\tt \bfseries reduce\_sum }!{\tt (F f, T[] x, int grainsize, T1 s1, T2 s2, ...): real}|hyperpage}
+
+`real` **`reduce_sum`**`(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)`<br>\newline
+
+Return the equivalent of `f(1, size(x), x, s1, s2, ...)`, but compute
+the result in parallel by breaking the array `x` into pieces and computing each piece
+in a different thread. `s1, s2, ...` are shared between all terms in the sum.
+
+* *`f`*: function literal referring to a function specifying the partial sum operation with signature `(int, int, T[], T1, T2, ...):real`
 The arguments represent
-  + (1) the index of the first term of the reduction,
-  + (2) the index of the last term of the reduction,
-  + (3) the subset `sliced_arg` this reduce is responsible for computing,
+  + (1) the index of the first term of the partial sum,
+  + (2) the index of the last term of the partial sum,
+  + (3) the subset of `x` this partial sum is responsible for computing,
   + (4) first shared argument,
   + (5) second shared argument,
   + ... the rest of the shared arguments.
 
-* *`sliced_args`*: array of non-shared arguments, one for each term of the reduction, array of `T`, where `T` can be any type,
-* *`grainsize`*: recommented number of terms in each reduce call, set to zero to estimate automatically, type `int`,
-* *`arg1`*: first (optional) shared argument, type `T1`, where `T1` can be any type
-* *`arg2`*: second (optional) shared argument, type `T2`, where `T2` can be any type,
+* *`x`*: array of `T`, one element for each term of the reduction, where `T` can be any type,
+* *`grainsize`*: recommended number of terms in each partial sum, set to one to estimate automatically, type `int`,
+* *`s1`*: first (optional) shared argument, type `T1`, where `T1` can be any type,
+* *`s2`*: second (optional) shared argument, type `T2`, where `T2` can be any type,
 * *`...`*: remainder of shared arguments, each of which can be any type.
 
-### Specifying the Reduce Function
+### The Partial-Sum Function
 
-The reduce function must have the following signature where the types T, and the
-types of all the variadic arguments (`T1`, `T2`, ...) match those of the original
+The partial sum function must have the following signature, where the type `T` and the
+types of all the shared arguments (`T1`, `T2`, ...) match those of the original
 `reduce_sum` call.
 
 ```
-(int start, int end, T[] subset_sliced_arg, T1 arg1, T2 arg2, ...):real
+(int start, int end, T[] x_subset, T1 s1, T2 s2, ...):real
 ```
 
-The reduce function returns the sum of the `start` to `end` terms of the overall
+The reduce function returns the sum of the `start` to `end` terms (inclusive) of the overall
 calculations. The arguments to the reduce function are:
 
-* *`start`*, the index of the first term of the reduction, type `int`
+* *`start`*, the index of the first term of the partial sum, type `int`
 
-* *`end`*, the index of the last term of the reduction (inclusive), type `int`
+* *`end`*, the index of the last term of the partial sum (inclusive), type `int`
 
-* *`subset_sliced_arg`*, the subset `sliced_arg` this reduce is responsible for computing, type `T[]`, where `T` matches the type of `sliced_arg` in `reduce_sum`
+* *`x_subset`*, the subset of `x` this partial sum is responsible for computing, type `T[]`, where `T` matches the type of `x` in `reduce_sum`
 
-* *`arg1`*, first shared argument, type `T1`, matching type of `arg1` in `reduce_sum`
+* *`s1`*, first shared argument, type `T1`, matching the type of `s1` in `reduce_sum`
 
-* *`arg2`*, second shared argument, type `T2`, matching type of `arg2` in `reduce_sum`
+* *`s2`*, second shared argument, type `T2`, matching the type of `s2` in `reduce_sum`
 
 * *`...`*, remainder of shared arguments, with types matching those in `reduce_sum`
 
 
-## Parallel Map-Rect Function {#functions-map}
+## Map-Rect Function {#functions-map}
 
 Stan provides a higher-order map function. This allows map-reduce
 functionality to be coded in Stan as described in the user's guide.
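
To make the signature above concrete, a minimal sketch of a partial-sum function and its `reduce_sum` call might look like the following (the names `partial_sum`, `x`, `mu`, and `sigma` are hypothetical, not taken from the changed files):

```
functions {
  // Hypothetical partial-sum function: sums the normal log densities of the
  // slice x_slice = x[start:end] that reduce_sum hands to this call.
  real partial_sum(int start, int end, real[] x_slice, real mu, real sigma) {
    return normal_lpdf(x_slice | mu, sigma);
  }
}
```

In the model block, `target += reduce_sum(partial_sum, x, grainsize, mu, sigma);` then accumulates the same total as `normal_lpdf(x | mu, sigma)`, computed over slices of `x` in parallel.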

src/stan-users-guide/parallel-computing.Rmd

Lines changed: 99 additions & 49 deletions
@@ -2,82 +2,106 @@
 
 Stan has two mechanisms for parallelizing calculations used in a model: `reduce_sum` and `map_rect`.
 
-The main differences are:
+The main advantages of `reduce_sum` are:
 
-1. `reduce_sum` requires the result of the calculation to be a scalar, while `map_rect` returns a list of vectors
-2. `reduce_sum` has a more flexible interface and can accept arbitrary Stan types as arguments, `map_rect` is more restrictive on what arguments can be and how they are shaped
-3. `map_rect` can parallelize work over multiple computers or a single computer, while `reduce_sum` works only on a single computer
-4. `map_rect` requires work to be broken into pieces manually, while `reduce_sum` mostly automates this
+1. `reduce_sum` has a more flexible argument interface, avoiding the packing and unpacking that is necessary with `map_rect`.
+2. `reduce_sum` partitions the data for parallelization automatically (this is done manually in `map_rect`).
+3. `reduce_sum` is easier to use.
 
-## Reduce-Sum { #reduce-sum }
+while the advantages of `map_rect` are:
+
+1. `map_rect` returns a list of vectors, while `reduce_sum` returns only a real.
+2. `map_rect` can be parallelized across multiple computers, while `reduce_sum` can only be parallelized across multiple cores.
 
-```reduce_sum``` is a tool for parallelizing operations that can be represented as a parallel-for combined with a sum (that returns a scalar).
+## Reduce-Sum { #reduce-sum }
 
-In terms of probabilistic models, an example of this comes up when N terms in a likelihood combined multiplicatively and can be computed independently of each other (where independence here is in the computational sense, not necessarily the statistical sense). In this case, computing the log density means computing the sum of a number of terms that can each be computed separately.
+```reduce_sum``` is a tool for parallelizing operations that can be represented as a sum of evaluations of a function, `g: U -> real`.
 
-```reduce_sum``` is not useful when there are dependencies between the terms. This can happen, for instance, if there were N terms in a Gaussian process likelihood. ```reduce_sum``` will not be useful for accelerating this.
+For instance, for a sequence of ```x``` values of type ```U```, ```{ x1, x2, ... }```, we might compute the sum:
 
-If for a set of input arguments, ```args0, args1, args2, ...``` and a scalar function ```f```, the log likelihood can be computed as:
+```g(x1) + g(x2) + ...```
 
-```f(args0) + f(args1) + f(args2) + ...```
+In probabilistic modeling this comes up when there are N conditionally independent terms in a likelihood. Because of the conditional independence, these terms can be computed in parallel. If dependencies exist between the terms, then this isn't possible. For instance, in evaluating the log density of a Gaussian process, ```reduce_sum``` would not be very useful.
 
-then this calculation can be written as a reduction over the set of arguments. If this reducing function is called ```reduce```, then it would need to perform the operations:
+```reduce_sum``` doesn't actually take ```g: U -> real``` as an input argument. Instead it takes ```f: U[] -> real```, where ```f``` computes the partial sum corresponding to the slice of the sequence ```x``` passed in. For instance:
 
-```reduce({ args0, args1, args2, ... }) = f(args0) + f(args1) + f(args2) + ...```
+```
+f({ x1, x2, x3 }) = g(x1) + g(x2) + g(x3)
+f({ x1 }) = g(x1)
+f({ x1, x2, x3 }) = f({ x1, x2 }) + f({ x3 })
+```
 
-If the user can write a function like ```reduce```, then it is trivial for us to provide a function to automatically parallelize the calculations.
+If the user can write a function ```f: U[] -> real``` to compute the necessary partial sums in the calculation, then we can provide a function to automatically parallelize the calculations (and this is what ```reduce_sum``` is).
 
-Again, if the set of work is represented as a list of arguments ``{ args0, args1, args2, ... }```, then mathematically it is possible to rewrite this sum with any combination of partial-reduces.
+If the set of work is represented as an array ```{ x1, x2, x3, ... }```, then mathematically it is possible to rewrite this sum with any combination of partial sums.
 
 For instance, the sum can be written:
 
-1. ```reduce({ args0, args1, args2, ... })```, summing over all arguments, using one reduce function
-2. ```reduce({ args0, ..., args(M - 1) }) + reduce({ argsM, args2, ...})```, computing the first M terms separately from the rest
-3. ```reduce({ args0 }) + reduce({ args1 }) + reduce({ args2 }) + ...```, computing each term individually and summing them
+1. ```f({ x1, x2, x3, ... })```, summing over all arguments, using one partial sum
+2. ```f({ x1, ..., xM }) + f({ x(M + 1), x(M + 2), ... })```, computing the first M terms separately from the rest
+3. ```f({ x1 }) + f({ x2 }) + f({ x3 }) + ...```, computing each term individually and summing them
 
-The first function call is completely serial, the second can be parallelized over two workers, and the last can be parallelized over as many workers as there are arguments. Depending on how the list is sliced up, it is possible to parallelize this calculation over many workers.
+The first form uses only one partial sum, so no parallelization can be done; the second uses two partial sums and so can be parallelized over two workers; and the last can be parallelized over as many workers as there are elements in the array ```x```. Depending on how the list is sliced up, it is possible to parallelize this calculation over many workers.
 
-```reduce_sum``` is the tool that will allow us to automatically parallelize these reduce operations (and sum them together).
+```reduce_sum``` is the tool that will allow us to automatically parallelize this calculation.
+
+For efficiency and convenience, ```reduce_sum``` allows for additional shared arguments to be passed to every term in the sum. So for the array ```{ x1, x2, ... }``` and the shared arguments ```s1, s2, ...``` the effective sum (with individual terms) looks like:
+
+```
+g(x1, s1, s2, ...) + g(x2, s1, s2, ...) + g(x3, s1, s2, ...) + ...
+```
 
-To implement this efficiently in Stan, the individual arguments are split into two types. The first are the arguments that are specific to each term in the reduction. These are called the sliced arguments (because we will slice these up to determine how to distribute the work). The second type of arguments are shared arguments, and are information that is shared in the computation of every term in the sum.
+which can be written equivalently with partial sums to look like:
+
+```
+f({ x1, x2 }, s1, s2, ...) + f({ x3 }, s1, s2, ...)
+```
+
+where the particular slicing of the ```x``` array can change.
 
 Given this, the signature for reduce_sum is:
 
 ```
-real reduce_sum(F func, T[] sliced_arg, int grainsize, T1 arg1, T2 arg2, ...)
+real reduce_sum(F func, T[] x, int grainsize, T1 s1, T2 s2, ...)
 ```
 
-1. ```func``` - The user-defined reduce function
-2. ```sliced_arg``` - An array of any type, with each element of the array corresponding to a term of the final summation (the length of ```sliced_arg``` is the total number of terms to sum)
-3. ```grainsize``` - A hint to the runtime of how many terms of the summation to compute in each reduction
-4. ```arg1, arg2, ...``` - All the arguments that are to be shared in the calculation of every term in the sum
+1. ```func``` - User-defined function that computes partial sums
+2. ```x``` - Array to slice, each element corresponds to a term in the summation
+3. ```grainsize``` - Target for size of slices
+4. ```s1, s2, ...``` - Arguments shared in every term
 
-The user-defined reduce function is slightly different:
+The user-defined partial sum functions have the signature:
 
 ```
-real func(int start, int end, T[] subset_sliced_arg, T1 arg1, T2 arg2, ...)
+real func(int start, int end, T[] x_subset, T1 s1, T2 s2, ...)
 ```
 
-and takes the arguments:
-1. ```start``` - An integer specifying the first element of the sequence of terms this reduce call is responsible for computing
-2. ```end``` - An integer specifying the last element of the sequence of terms this reduce call is responsible for computing
-3. ```subset_sliced_arg``` - The subset of sliced_arg for which this reduce is responsible (```sliced_arg[start:end]```)
-4. ```arg1, arg2, ...``` all the shared arguments -- passed on without modification from the reduce_sum call
+and take the arguments:
+1. ```start``` - An integer specifying the first term in the partial sum
+2. ```end``` - An integer specifying the last term in the partial sum (inclusive)
+3. ```x_subset``` - The subset of ```x``` (from ```reduce_sum```) for which this partial sum is responsible (```x[start:end]```)
+4. ```s1, s2, ...``` - Arguments shared in every term (passed on without modification from the reduce_sum call)
 
-The user-provided function ```func``` is expect to compute the ```start``` through ```end``` terms of the overall sum, accumulate them, and return that value. The user function is only passed the subset ```sliced_arg[start:end]``` of sliced arg (as ```subset_sliced_arg```). ```start``` and ```end``` are passed so that ```func``` can index any of the ```argM``` appropriately. The trailing arguments ```argM``` are passed without modification to every call of ```func```.
+The user-provided function ```func``` is expected to compute the ```start``` through ```end``` terms of the overall sum, accumulate them, and return that value. The user function is passed the subset ```x[start:end]``` as ```x_subset```. ```start``` and ```end``` are passed so that ```func``` can index any of the trailing ```sM``` arguments as necessary. The trailing ```sM``` arguments are passed without modification to every call of ```func```.
+
+The ```reduce_sum``` call:
+
+```
+real sum = reduce_sum(func, x, grainsize, s1, s2, ...)
+```
 
-An overall call to ```reduce_sum``` can be replaced by either:
+can be replaced by either:
 
 ```
-real sum = func(1, size(sliced_arg), sliced_arg, arg1, arg2, ...)
+real sum = func(1, size(x), x, s1, s2, ...)
 ```
 
-or (modulo differences due to rearrangements of summations) the code:
+or the code:
 
 ```
 real sum = 0.0;
-for(i in 1:size(sliced_arg)) {
-  sum = sum + func(i, i, { sliced_arg[i] }, arg1, arg2, ...);
+for(i in 1:size(x)) {
+  sum = sum + func(i, i, { x[i] }, s1, s2, ...);
 }
 ```
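
As a toy illustration of the partial-sum idea above (a hypothetical `g(x) = x^2`, not part of the files in this diff), the partial-sum function simply sums squares over whatever slice it receives:

```
functions {
  // Hypothetical partial-sum function for g(x) = x^2: sums the squares of
  // the elements in the slice x_slice = x[start:end].
  real sum_of_squares(int start, int end, real[] x_slice) {
    real total = 0;
    for (n in 1:size(x_slice)) {
      total += x_slice[n]^2;
    }
    return total;
  }
}
```

With `x = { 1, 2, 3 }`, `f({ 1, 2, 3 }) = 14` and `f({ 1, 2 }) + f({ 3 }) = 5 + 9 = 14`, so `reduce_sum(sum_of_squares, x, grainsize)` returns the same value however `x` is sliced.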

@@ -126,26 +150,26 @@ independent Bernoulli log probability statements, which is the condition where
 `reduce_sum` is useful.
 
 To use `reduce_sum`, a function must be written that can be used to compute
-arbitrary subsets of the sums.
+arbitrary partial sums of the total sum.
 
-Using the reducer interface defined in [Reduce-Sum](#reduce-sum), such a function
+Using the interface defined in [Reduce-Sum](#reduce-sum), such a function
 can be written like:
 
 ```
 functions {
-  real reducer_func(int start, int end,
-                    int[] subset_y,
-                    vector x,
-                    vector beta) {
-    return bernoulli_logit_lpmf(subset_y | beta[1] + beta[2] * x[start:end]);
+  real partial_sum(int start, int end,
+                   int[] y_subset,
+                   vector x,
+                   vector beta) {
+    return bernoulli_logit_lpmf(y_subset | beta[1] + beta[2] * x[start:end]);
   }
 }
 ```
 
 And the likelihood statement in the model can now be written:
 
 ```
-target += reducer_fun(1, N, y, x, beta); // Sum terms 1 to N in the likelihood
+target += partial_sum(1, N, y, x, beta); // Sum terms 1 to N in the likelihood
 ```
 
 In this example, `y` was chosen to be sliced over because there
@@ -159,14 +183,40 @@ likelihood:
 
 ```
 int grainsize = 100;
-target += reduce_sum(reducer_func, y,
+target += reduce_sum(partial_sum, y,
                      grainsize,
                      x, beta);
 ```
 
 `reduce_sum` automatically breaks the sum into roughly `grainsize` sized pieces
 and computes them in parallel. `grainsize = 1` specifies that the grainsize should
-be estimated automatically.
+be estimated automatically. The final model looks like:
+
+```
+functions {
+  real partial_sum(int start, int end,
+                   int[] y_subset,
+                   vector x,
+                   vector beta) {
+    return bernoulli_logit_lpmf(y_subset | beta[1] + beta[2] * x[start:end]);
+  }
+}
+data {
+  int N;
+  int y[N];
+  vector[N] x;
+}
+parameters {
+  vector[2] beta;
+}
+model {
+  int grainsize = 100;
+  beta ~ std_normal();
+  target += reduce_sum(partial_sum, y,
+                       grainsize,
+                       x, beta);
+}
+```
 
 ### Picking the Grainsize
 
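As a rough usage sketch (hypothetical, assuming the `partial_sum` function from the model above; not text from this commit), `grainsize` can also be passed in as data so that different values can be tried without recompiling:

```
data {
  int N;
  int y[N];
  vector[N] x;
  int grainsize;  // hypothetical: 1 lets the scheduler choose; larger values can be tuned by hand
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y, grainsize, x, beta);
}
```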