Skip to content

Commit 4e134f0

Browse files
authored
Merge pull request #171 from stan-dev/feature/reduce-sum-argument-order
Changed argument order for reduce_sum partial_sum function (follow up to pull #161)
2 parents 50cdcae + e7e3f11 commit 4e134f0

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/functions-reference/higher-order_functions.Rmd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ parallelization of the resultant sum.
445445
`real` **`reduce_sum`**`(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)`<br>\newline
446446
`real` **`reduce_sum_static`**`(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)`<br>\newline
447447

448-
Returns the equivalent of `f(1, size(x), x, s1, s2, ...)`, but computes
448+
Returns the equivalent of `f(x, 1, size(x), s1, s2, ...)`, but computes
449449
the result in parallel by breaking the array `x` into independent
450450
partial sums. `s1, s2, ...` are shared between all terms in the sum.
451451

@@ -464,18 +464,18 @@ types of all the shared arguments (`T1`, `T2`, ...) match those of the original
464464
`reduce_sum` (`reduce_sum_static`) call.
465465

466466
```
467-
(int start, int end, T[] x_subset, T1 s1, T2 s2, ...):real
467+
(T[] x_subset, int start, int end, T1 s1, T2 s2, ...):real
468468
```
469469

470470
The partial sum function returns the sum of the `start` to `end` terms (inclusive) of the overall
471471
calculations. The arguments to the partial sum function are:
472472

473+
* *`x_subset`*, the subset of `x` a given partial sum is responsible for computing, type `T[]`, where `T` matches the type of `x` in `reduce_sum` (`reduce_sum_static`)
474+
473475
* *`start`*, the index of the first term of the partial sum, type `int`
474476

475477
* *`end`*, the index of the last term of the partial sum (inclusive), type `int`
476478

477-
* *`x_subset`*, the subset of `x` a given partial sum is responsible for computing, type `T[]`, where `T` matches the type of `x` in `reduce_sum` (`reduce_sum_static`)
478-
479479
* *`s1`*, first shared argument, type `T1`, matching type of `s1` in `reduce_sum` (`reduce_sum_static`)
480480

481481
* *`s2`*, second shared argument, type `T2`, matching type of `s2` in `reduce_sum` (`reduce_sum_static`)

src/stan-users-guide/parallelization.Rmd

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,15 @@ real reduce_sum_static(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)
108108
The user-defined partial sum functions have the signature:
109109

110110
```
111-
real f(int start, int end, T[] x_slice, T1 s1, T2 s2, ...)
111+
real f(T[] x_slice, int start, int end, T1 s1, T2 s2, ...)
112112
```
113113

114114
and take the arguments:
115115

116-
1. ```start``` - An integer specifying the first term in the partial sum
117-
2. ```end``` - An integer specifying the last term in the partial sum (inclusive)
118-
3. ```x_slice``` - The subset of ```x``` (from ```reduce_sum``` / `reduce_sum_static`) for
119-
which this partial sum is responsible (```x_slice = x[start:end]```)
116+
1. ```x_slice``` - The subset of ```x``` (from ```reduce_sum``` / `reduce_sum_static`) for
117+
which this partial sum is responsible (```x_slice = x[start:end]```)
118+
2. ```start``` - An integer specifying the first term in the partial sum
119+
3. ```end``` - An integer specifying the last term in the partial sum (inclusive)
120120
4. ```s1, s2, ...``` - Arguments shared in every term (passed on
121121
without modification from the ```reduce_sum``` / `reduce_sum_static` call)
122122

@@ -137,15 +137,15 @@ real sum = reduce_sum(f, x, grainsize, s1, s2, ...);
137137
can be replaced by either:
138138

139139
```
140-
real sum = f(1, size(x), x, s1, s2, ...);
140+
real sum = f(x, 1, size(x), s1, s2, ...);
141141
```
142142

143143
or the code:
144144

145145
```
146146
real sum = 0.0;
147147
for(i in 1:size(x)) {
148-
sum += f(i, i, { x[i] }, s1, s2, ...);
148+
sum += f({ x[i] }, i, i, s1, s2, ...);
149149
}
150150
```
151151

@@ -180,6 +180,7 @@ y ~ bernoulli_logit(beta[1] + beta[2] * x);
180180
```
181181

182182
can be rewritten (up to a proportionality constant) as:
183+
183184
```
184185
for(n in 1:N) {
185186
target += bernoulli_logit_lpmf(y[n] | beta[1] + beta[2] * x[n])
@@ -195,8 +196,8 @@ the total sum. Using the interface defined in
195196

196197
```
197198
functions {
198-
real partial_sum(int start, int end,
199-
int[] y_slice,
199+
real partial_sum(int[] y_slice,
200+
int start, int end,
200201
vector x,
201202
vector beta) {
202203
return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
@@ -207,7 +208,7 @@ functions {
207208
The likelihood statement in the model can now be written:
208209

209210
```
210-
target += partial_sum(1, N, y, x, beta); // Sum terms 1 to N of the likelihood
211+
target += partial_sum(y, 1, N, x, beta); // Sum terms 1 to N of the likelihood
211212
```
212213

213214
In this example, `y` was chosen to be sliced over because there
@@ -232,8 +233,8 @@ and computes them in parallel. `grainsize = 1` specifies that the
232233

233234
```
234235
functions {
235-
real partial_sum(int start, int end,
236-
int[] y_slice,
236+
real partial_sum(int[] y_slice,
237+
int start, int end,
237238
vector x,
238239
vector beta) {
239240
return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);

0 commit comments

Comments
 (0)