11---
22title : " Reduce Sum: A Minimal Example"
3- date : " 16 June 2020"
3+ date : " 2 Dec 2020"
44output : html_document
55---
66
@@ -13,7 +13,10 @@ This introduction to `reduce_sum` copies directly from Richard McElreath's
1313
1414## Introduction
1515
16- Stan 2.23 introduced ` reduce_sum ` , a new way to parallelize the execution of
16+ ** Note:** This has been rewritten to use unnormalized distribution functions
17+ ( ` _lupdf ` / ` _lupmf ` ) which requires Cmdstan 2.25 or newer.
18+
19+ Cmdstan 2.23 introduced ` reduce_sum ` , a new way to parallelize the execution of
1720a single Stan chain across multiple cores. This is in addition to the already
1821existing ` map_rect ` utility, and introduces a number of features that make it
1922easier to use parallelism:
@@ -149,59 +152,79 @@ statement:
149152n_redcards ~ binomial_logit(n_games, beta[1] + beta[2] * rating);
150153```
151154
152- can be rewritten (up to a proportionality constant) as:
155+ can be rewritten as:
153156
154157``` {stan, output.var = "", eval = FALSE}
155158for(n in 1:N) {
156- target += binomial_logit_lpmf (n_redcards[n] | n_games[n], beta[1] + beta[2] * rating[n])
159+ target += binomial_logit_lupmf (n_redcards[n] | n_games[n], beta[1] + beta[2] * rating[n])
157160}
158161```
159162
160- Now it is clear that the calculation is the sum of a number of
161- conditionally independent Bernoulli log probability statements. So
162- whenever we need to calculate a large sum where each term is
163- independent of all others and associativity holds, then ` reduce_sum `
164- is useful.
163+ Now it is clear that the calculation is the sum (up to a
164+ proportionality constant) of a number of conditionally independent
165+ binomial log probability statements. Whenever we need to calculate
166+ a large sum where each term is independent of all others and associativity
167+ holds, then ` reduce_sum ` is useful.
165168
166169To use ` reduce_sum ` , a function must be written that can be used to compute
167170arbitrary sections of this sum.
168171
172+ Note we are using ` binomial_logit_lupmf ` instead of ` binomial_logit_lpmf ` .
173+ This is because we only need this likelihood term up to a proportionality
174+ constant for MCMC to work and for some distributions this can make code
175+ run noticeably faster. There is a catch though: Stan only allows ` _lupmf `
176+ in the model block or in user-defined probability distribution functions.
177+ Thus, for us to use ` binomial_logit_lupmf ` , the function we write for
178+ ` reduce_sum ` must be a user-defined probability distribution function
179+ (which means it must be suffixed with ` _lpdf ` or ` _lpmf ` ).
180+
181+ If the difference in the performance of normalized and unnormalized functions
182+ is not relevant for your application, you can call your ` reduce_sum ` function
183+ whatever you like.
184+
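As a sketch of what "dropping constants" means here, the binomial log probability mass function on the logit scale can be written out term by term. The symbols $y$, $n$, and $\alpha$ are shorthand introduced for this note only; they stand for one element of `n_redcards`, the matching element of `n_games`, and `beta[1] + beta[2] * rating` for that observation:

$$
\log \text{Binomial}\left(y \mid n, \text{logit}^{-1}(\alpha)\right)
  = \log \binom{n}{y}
  + y \log \text{logit}^{-1}(\alpha)
  + (n - y) \log\left(1 - \text{logit}^{-1}(\alpha)\right)
$$

The first term depends only on data, not on `beta`, so an unnormalized function is free to skip it. Which terms Stan actually omits is an implementation detail; this is only meant to illustrate why the unnormalized form can be cheaper.
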
169185Using the reducer interface defined in
170186[ Reduce-Sum] ( https://mc-stan.org/docs/2_23/functions-reference/functions-reduce.html ) :
171187
172188``` {stan, output.var = "", eval = FALSE}
173189functions {
174- real partial_sum (int[] slice_n_redcards,
175- int start, int end,
176- int[] n_games,
177- vector rating,
178- vector beta) {
179- return binomial_logit_lpmf (slice_n_redcards |
180- n_games[start:end],
181- beta[1] + beta[2] * rating[start:end]);
190+ real partial_sum_lpmf (int[] slice_n_redcards,
191+ int start, int end,
192+ int[] n_games,
193+ vector rating,
194+ vector beta) {
195+ return binomial_logit_lupmf (slice_n_redcards |
196+ n_games[start:end],
197+ beta[1] + beta[2] * rating[start:end]);
182198 }
183199}
184200```
185201
186202The likelihood statement in the model can now be written:
187203
188204``` {stan, output.var = "", eval = FALSE}
189- target += partial_sum (n_redcards, 1, N, n_games, rating, beta); // Sum terms 1 to N in the likelihood
205+ target += partial_sum_lupmf (n_redcards, 1, N, n_games, rating, beta); // Sum terms 1 to N in the likelihood
190206```
191207
192- Equivalently it could be broken into two pieces and written like:
208+ Note that we're calling ` partial_sum_lupmf ` even though we defined the
209+ function ` partial_sum_lpmf ` . ` partial_sum_lupmf ` is implicitly defined when
210+ we write ` partial_sum_lpmf ` and is a special version of the function that
211+ will signify to all the ` _lupmf ` calls inside it that it is okay to drop
212+ constants. If we call ` partial_sum_lpmf ` , the ` binomial_logit_lupmf ` function
213+ call will not drop constants (and hence be slower).
214+
215+ Equivalently this partial sum could be broken into two pieces and written like:
193216
194217``` {stan, output.var = "", eval = FALSE}
195218int M = N / 2;
196- target += partial_sum (n_redcards[1:M], 1, M, n_games, rating, beta) // Sum terms 1 to M
197- target += partial_sum (n_redcards[(M + 1):N], M + 1, N, n_games, rating, beta); // Sum terms M + 1 to N
219+ target += partial_sum_lupmf (n_redcards[1:M], 1, M, n_games, rating, beta); // Sum terms 1 to M
220+ target += partial_sum_lupmf (n_redcards[(M + 1):N], M + 1, N, n_games, rating, beta); // Sum terms M + 1 to N
198221```
199222
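In symbols, the two-piece version relies only on the associativity of addition. Writing $\ell_n$ (notation assumed for this note, not taken from the original) for the $n$-th term `binomial_logit_lupmf(n_redcards[n] | n_games[n], beta[1] + beta[2] * rating[n])`:

$$
\sum_{n=1}^{N} \ell_n
  = \sum_{n=1}^{M} \ell_n + \sum_{n=M+1}^{N} \ell_n,
  \qquad M = \lfloor N / 2 \rfloor
$$

More generally, a call with arguments `start` and `end` computes $\sum_{n=\text{start}}^{\text{end}} \ell_n$, so any partition of $1, \dots, N$ into contiguous slices gives the same total.
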
200- By passing ` partial_sum ` to ` reduce_sum ` , we allow Stan to
223+ By passing ` partial_sum_lupmf ` to ` reduce_sum ` , we tell Stan to
201224automatically break up these calculations and do them in parallel.
202225
203226Notice the difference in how ` n_redcards ` is split in half (to reflect
204- which terms of the sum are being accumulated) and the rest of the arguments
227+ which terms of the sum are being accumulated) and the rest of the arguments
205228(` n_games ` , ` rating ` , and ` beta ` ) are left alone. This distinction is important
206229and more fully described in the User's Guide section on
207230[ Reduce-sum] ( https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html ) .
@@ -211,7 +234,7 @@ likelihood:
211234
212235``` {stan, output.var = "", eval = FALSE}
213236int grainsize = 1;
214- target += reduce_sum(partial_sum , n_redcards, grainsize,
237+ target += reduce_sum(partial_sum_lupmf , n_redcards, grainsize,
215238 n_games, rating, beta);
216239```
217240
@@ -221,16 +244,20 @@ be estimated automatically (`grainsize` should be left at 1 unless specific test
221244are done to
222245[ pick a different one] ( https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html#reduce-sum-grainsize ) ).
223246
247+ Again, if we passed ` partial_sum_lpmf ` to ` reduce_sum ` instead of
248+ ` partial_sum_lupmf ` we would not take advantage of the performance benefits
249+ of using ` bernoulli_logit_lupmf ` .
250+
224251Making ` grainsize ` data (this makes it convenient to experiment with), the final
225252model is:
226253``` {stan, output.var = "", eval = FALSE}
227254functions {
228- real partial_sum (int[] slice_n_redcards,
229- int start, int end,
230- int[] n_games,
231- vector rating,
232- vector beta) {
233- return binomial_logit_lpmf (slice_n_redcards |
255+ real partial_sum_lpmf (int[] slice_n_redcards,
256+ int start, int end,
257+ int[] n_games,
258+ vector rating,
259+ vector beta) {
260+ return binomial_logit_lupmf (slice_n_redcards |
234261 n_games[start:end],
235262 beta[1] + beta[2] * rating[start:end]);
236263 }
@@ -250,7 +277,7 @@ model {
250277 beta[1] ~ normal(0, 10);
251278 beta[2] ~ normal(0, 1);
252279
253- target += reduce_sum(partial_sum , n_redcards, grainsize,
280+ target += reduce_sum(partial_sum_lupmf , n_redcards, grainsize,
254281 n_games, rating, beta);
255282}
256283```
@@ -311,11 +338,11 @@ to check diagnostics. `reduce_sum` is a tool for speeding up single chain
311338calculations, which can be useful for model development and on computers with
312339large numbers of cores.
313340
314- We can do a quick check that these two methods are mixing with posterior.
341+ We can do a quick check that these two methods are mixing with the ` posterior `
342+ package (https://github.com/stan-dev/posterior ).
315343When parallelizing a model, this is a good check to make sure nothing is
316344breaking:
317345``` {r}
318- remotes::install_github("jgabry/posterior")
319346library(posterior)
320347summarise_draws(bind_draws(fit0$draws(), fit1$draws(), along = "chain"))
321348```