Commit 9a6f4aa

Updated reduce_sum docs to reflect addition of lupmf
1 parent bf2c0d5 commit 9a6f4aa

2 files changed: 63 additions, 40 deletions
knitr/reduce-sum/logistic1.stan

Lines changed: 10 additions & 10 deletions
@@ -1,12 +1,12 @@
 functions {
-  real partial_sum(int[] slice_n_redcards,
-                   int start, int end,
-                   int[] n_games,
-                   vector rating,
-                   vector beta) {
-    return binomial_logit_lpmf(slice_n_redcards |
-                               n_games[start:end],
-                               beta[1] + beta[2] * rating[start:end]);
+  real partial_sum_lpmf(int[] slice_n_redcards,
+                        int start, int end,
+                        int[] n_games,
+                        vector rating,
+                        vector beta) {
+    return binomial_logit_lupmf(slice_n_redcards |
+                                n_games[start:end],
+                                beta[1] + beta[2] * rating[start:end]);
   }
 }
 data {
@@ -24,6 +24,6 @@ model {
   beta[1] ~ normal(0, 10);
   beta[2] ~ normal(0, 1);
 
-  target += reduce_sum(partial_sum, n_redcards, grainsize,
+  target += reduce_sum(partial_sum_lupmf, n_redcards, grainsize,
                        n_games, rating, beta);
-}
+}

knitr/reduce-sum/reduce_sum_tutorial.Rmd

Lines changed: 53 additions & 30 deletions
@@ -149,55 +149,74 @@ statement:
 n_redcards ~ binomial_logit(n_games, beta[1] + beta[2] * rating);
 ```
 
-can be rewritten (up to a proportionality constant) as:
+can be rewritten as:
 
 ```{stan, output.var = "", eval = FALSE}
 for(n in 1:N) {
-  target += binomial_logit_lpmf(n_redcards[n] | n_games[n], beta[1] + beta[2] * rating[n]);
+  target += binomial_logit_lupmf(n_redcards[n] | n_games[n], beta[1] + beta[2] * rating[n]);
 }
 ```
 
-Now it is clear that the calculation is the sum of a number of
-conditionally independent Bernoulli log probability statements. So
-whenever we need to calculate a large sum where each term is
-independent of all others and associativity holds, then `reduce_sum`
-is useful.
+Now it is clear that the calculation is the sum (up to a
+proportionality constant) of a number of conditionally independent
+Bernoulli log probability statements. So whenever we need to calculate
+a large sum where each term is independent of all others and associativity
+holds, then `reduce_sum` is useful.
 
 To use `reduce_sum`, a function must be written that can be used to compute
 arbitrary sections of this sum.
 
+Note we used `binomial_logit_lupmf` instead of `binomial_logit_lpmf`.
+This is because we only need this likelihood term up to a proportionality
+constant for MCMC to work and for some distributions this can make code
+run noticeably faster. Because of the way that `_lupmf` features work,
+Stan only allows them in the model block or in user-defined probability
+distribution functions, and so the function we write for `reduce_sum`
+will need to be a probability distribution function (suffixed with
+`_lpdf` or `_lpmf`) for us to use `binomial_logit_lupmf` on the inside.
+If the difference in the normalized and unnormalized functions is not
+relevant for your application, you can call your `reduce_sum` function
+whatever you like.
+
 Using the reducer interface defined in
 [Reduce-Sum](https://mc-stan.org/docs/2_23/functions-reference/functions-reduce.html):
 
 ```{stan, output.var = "", eval = FALSE}
 functions {
-  real partial_sum(int[] slice_n_redcards,
-                   int start, int end,
-                   int[] n_games,
-                   vector rating,
-                   vector beta) {
-    return binomial_logit_lpmf(slice_n_redcards |
-                               n_games[start:end],
-                               beta[1] + beta[2] * rating[start:end]);
+  real partial_sum_lpmf(int[] slice_n_redcards,
+                        int start, int end,
+                        int[] n_games,
+                        vector rating,
+                        vector beta) {
+    return binomial_logit_lupmf(slice_n_redcards |
+                                n_games[start:end],
+                                beta[1] + beta[2] * rating[start:end]);
   }
 }
 ```
 
 The likelihood statement in the model can now be written:
 
 ```{stan, output.var = "", eval = FALSE}
-target += partial_sum(n_redcards, 1, N, n_games, rating, beta); // Sum terms 1 to N in the likelihood
+target += partial_sum_lupmf(n_redcards, 1, N, n_games, rating, beta); // Sum terms 1 to N in the likelihood
 ```
 
-Equivalently it could be broken into two pieces and written like:
+Note that we're calling `partial_sum_lupmf` even though we defined the
+function `partial_sum_lpmf`. `partial_sum_lupmf` is implicitly defined when
+we write `partial_sum_lpmf` and is a special version of the function that
+will signify to all the `_lupmf` calls inside it that it is okay to drop
+constants. If we call `partial_sum_lpmf`, the `binomial_logit_lupmf` function
+call will not drop constants (and hence be slower).
+
+Equivalently this partial sum could be broken into two pieces and written like:
 
 ```{stan, output.var = "", eval = FALSE}
 int M = N / 2;
-target += partial_sum(n_redcards[1:M], 1, M, n_games, rating, beta); // Sum terms 1 to M
-target += partial_sum(n_redcards[(M + 1):N], M + 1, N, n_games, rating, beta); // Sum terms M + 1 to N
+target += partial_sum_lupmf(n_redcards[1:M], 1, M, n_games, rating, beta); // Sum terms 1 to M
+target += partial_sum_lupmf(n_redcards[(M + 1):N], M + 1, N, n_games, rating, beta); // Sum terms M + 1 to N
 ```
 
-By passing `partial_sum` to `reduce_sum`, we allow Stan to
+By passing `partial_sum_lupmf` to `reduce_sum`, we allow Stan to
 automatically break up these calculations and do them in parallel.
 
 Notice the difference in how `n_redcards` is split in half (to reflect
@@ -211,7 +230,7 @@ likelihood:
 
 ```{stan, output.var = "", eval = FALSE}
 int grainsize = 1;
-target += reduce_sum(partial_sum, n_redcards, grainsize,
+target += reduce_sum(partial_sum_lupmf, n_redcards, grainsize,
                      n_games, rating, beta);
 ```
 
@@ -221,16 +240,20 @@ be estimated automatically (`grainsize` should be left at 1 unless specific test
 are done to
 [pick a different one](https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html#reduce-sum-grainsize)).
 
+Again, if we passed `partial_sum_lpmf` to `reduce_sum` instead of
+`partial_sum_lupmf` we would not take advantage of the performance benefits
+of using `binomial_logit_lupmf`.
+
 Making `grainsize` data (this makes it convenient to experiment with), the final
 model is:
 ```{stan, output.var = "", eval = FALSE}
 functions {
-  real partial_sum(int[] slice_n_redcards,
-                   int start, int end,
-                   int[] n_games,
-                   vector rating,
-                   vector beta) {
-    return binomial_logit_lpmf(slice_n_redcards |
+  real partial_sum_lpmf(int[] slice_n_redcards,
+                        int start, int end,
+                        int[] n_games,
+                        vector rating,
+                        vector beta) {
+    return binomial_logit_lupmf(slice_n_redcards |
                                n_games[start:end],
                                beta[1] + beta[2] * rating[start:end]);
   }
@@ -250,7 +273,7 @@ model {
   beta[1] ~ normal(0, 10);
   beta[2] ~ normal(0, 1);
 
-  target += reduce_sum(partial_sum, n_redcards, grainsize,
+  target += reduce_sum(partial_sum_lupmf, n_redcards, grainsize,
                        n_games, rating, beta);
 }
 ```
@@ -311,11 +334,11 @@ to check diagnostics. `reduce_sum` is a tool for speeding up single chain
 calculations, which can be useful for model development and on computers with
 large numbers of cores.
 
-We can do a quick check that these two methods are mixing with posterior.
+We can do a quick check that these two methods are mixing with the `posterior`
+package (https://github.com/stan-dev/posterior).
 When parallelizing a model, this is a good check to make sure nothing is
 breaking:
 ```{r}
-remotes::install_github("jgabry/posterior")
 library(posterior)
 summarise_draws(bind_draws(fit0$draws(), fit1$draws(), along = "chain"))
 ```
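
Reviewer note: the two ideas this commit documents can be checked numerically outside Stan. The sketch below is plain Python (standard library only), not Stan's implementation; the function names `log_inv_logit`, `binomial_logit_log_kernel`, `log_binomial_coefficient` and `partial_sum`, and the example data, are hypothetical and chosen only to mirror the tutorial. It assumes that the term dropped by the unnormalized form is the log binomial coefficient, which does not depend on `beta`. Under that assumption, the normalized and unnormalized sums should differ by the same constant for any `beta`, and splitting the sum into slices should give the same total as summing everything at once.

```python
import math


def log_inv_logit(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def binomial_logit_log_kernel(n, size, alpha):
    """Part of the binomial-logit log-PMF that depends on alpha."""
    return n * log_inv_logit(alpha) + (size - n) * log_inv_logit(-alpha)


def log_binomial_coefficient(n, size):
    """log(size choose n); constant with respect to the parameters."""
    return math.lgamma(size + 1) - math.lgamma(n + 1) - math.lgamma(size - n + 1)


def partial_sum(slice_n, start, end, n_games, rating, beta, normalized=False):
    """Sum terms start..end (1-based, inclusive, as in the Stan code).

    slice_n holds only the sliced counts; n_games and rating are the full
    arrays and are indexed with start, mirroring the tutorial's partial sum.
    """
    assert len(slice_n) == end - start + 1
    total = 0.0
    for offset, n in enumerate(slice_n):
        i = start - 1 + offset  # convert to a 0-based index
        alpha = beta[0] + beta[1] * rating[i]
        total += binomial_logit_log_kernel(n, n_games[i], alpha)
        if normalized:
            total += log_binomial_coefficient(n, n_games[i])
    return total


# Made-up example data, for illustration only.
n_redcards = [0, 1, 2, 0]
n_games = [10, 20, 15, 5]
rating = [-0.5, 0.1, 0.7, 0.0]
beta = [-3.0, 0.5]

N = len(n_redcards)
M = N // 2
full = partial_sum(n_redcards, 1, N, n_games, rating, beta)
split = (partial_sum(n_redcards[:M], 1, M, n_games, rating, beta)
         + partial_sum(n_redcards[M:], M + 1, N, n_games, rating, beta))
print("full sum:", full, "split sum:", split)
```

The `normalized` flag plays the role that the `_lpmf` versus `_lupmf` suffix plays in the Stan code. The only extra work in the normalized branch is the `lgamma` calls, which is the kind of cost the unnormalized form can skip.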
