
Commit 9048555

Merge pull request #3313 from stan-dev/feature/3299-chainset
Feature/3299 chainset
2 parents 2550bbd + 68a3cfb · commit 9048555

34 files changed: +10345 -2483 lines

src/stan/analyze/mcmc/compute_effective_sample_size.hpp

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,8 @@
 namespace stan {
 namespace analyze {
 /**
+ * \deprecated use split_rank_normalized_ess instead
+ *
  * Computes the effective sample size (ESS) for the specified
  * parameter across all kept samples. The value returned is the
  * minimum of ESS and the number_total_draws *
@@ -138,6 +140,8 @@ inline double compute_effective_sample_size(std::vector<const double*> draws,
 }

 /**
+ * \deprecated use split_rank_normalized_ess instead
+ *
  * Computes the effective sample size (ESS) for the specified
  * parameter across all kept samples. The value returned is the
  * minimum of ESS and the number_total_draws *
@@ -164,6 +168,8 @@ inline double compute_effective_sample_size(std::vector<const double*> draws,
 }

 /**
+ * \deprecated use split_rank_normalized_ess instead
+ *
  * Computes the split effective sample size (ESS) for the specified
  * parameter across all kept samples. The value returned is the
  * minimum of ESS and the number_total_draws *
@@ -199,6 +205,8 @@ inline double compute_split_effective_sample_size(
 }

 /**
+ * \deprecated use split_rank_normalized_ess instead
+ *
  * Computes the split effective sample size (ESS) for the specified
  * parameter across all kept samples. The value returned is the
  * minimum of ESS and the number_total_draws *

src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,8 @@ namespace stan {
 namespace analyze {

 /**
+ * \deprecated use `rhat` instead
+ *
  * Computes the potential scale reduction (Rhat) for the specified
  * parameter across all kept samples.
  *
@@ -102,6 +104,8 @@ inline double compute_potential_scale_reduction(
 }

 /**
+ * \deprecated use split_rank_normalized_rhat instead
+ *
  * Computes the potential scale reduction (Rhat) for the specified
  * parameter across all kept samples.
  *
@@ -125,6 +129,8 @@ inline double compute_potential_scale_reduction(
 }

 /**
+ * \deprecated use split_rank_normalized_rhat instead
+ *
  * Computes the split potential scale reduction (Rhat) for the
  * specified parameter across all kept samples. When the number of
  * total draws N is odd, the (N+1)/2th draw is ignored.
@@ -157,6 +163,8 @@ inline double compute_split_potential_scale_reduction(
 }

 /**
+ * \deprecated use split_rank_normalized_rhat instead
+ *
  * Computes the split potential scale reduction (Rhat) for the
  * specified parameter across all kept samples. When the number of
  * total draws N is odd, the (N+1)/2th draw is ignored.

src/stan/analyze/mcmc/ess.hpp

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+#ifndef STAN_ANALYZE_MCMC_ESS_HPP
+#define STAN_ANALYZE_MCMC_ESS_HPP
+
+#include <stan/math/prim.hpp>
+#include <stan/analyze/mcmc/autocovariance.hpp>
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include <limits>
+
+namespace stan {
+namespace analyze {
+
+/**
+ * Computes the effective sample size (ESS) for the specified
+ * parameter across all chains. The number of draws per chain must be > 3,
+ * and the values across all draws must be finite and not constant.
+ * See https://arxiv.org/abs/1903.08008, section 3.2 for discussion.
+ *
+ * Sample autocovariance is computed using the implementation in this namespace
+ * which normalizes lag-k autocorrelation estimators by N instead of (N - k),
+ * yielding biased but more stable estimators as discussed in Geyer (1992); see
+ * https://projecteuclid.org/euclid.ss/1177011137.
+ *
+ * @param chains matrix of draws across all chains
+ * @return effective sample size for the specified parameter
+ */
+double ess(const Eigen::MatrixXd& chains) {
+  const Eigen::Index num_chains = chains.cols();
+  const Eigen::Index draws_per_chain = chains.rows();
+  Eigen::MatrixXd acov(draws_per_chain, num_chains);
+  Eigen::VectorXd chain_mean(num_chains);
+  Eigen::VectorXd chain_var(num_chains);
+
+  // compute the per-chain autocovariance
+  for (size_t i = 0; i < num_chains; ++i) {
+    chain_mean(i) = chains.col(i).mean();
+    Eigen::Map<const Eigen::VectorXd> draw_col(chains.col(i).data(),
+                                               draws_per_chain);
+    Eigen::VectorXd cov_col(draws_per_chain);
+    autocovariance<double>(draw_col, cov_col);
+    acov.col(i) = cov_col;
+    chain_var(i) = cov_col(0) * draws_per_chain / (draws_per_chain - 1);
+  }
+
+  // compute var_plus, eqn (3)
+  double w_chain_var = math::mean(chain_var);  // W (within chain var)
+  double var_plus
+      = w_chain_var * (draws_per_chain - 1) / draws_per_chain;  // \hat{var}^{+}
+  if (num_chains > 1) {
+    var_plus += math::variance(chain_mean);  // B (between chain var)
+  }
+
+  // Geyer's initial positive sequence, eqn (11)
+  Eigen::VectorXd rho_hat_t = Eigen::VectorXd::Zero(draws_per_chain);
+  double rho_hat_even = 1.0;
+  rho_hat_t(0) = rho_hat_even;  // lag 0
+
+  Eigen::VectorXd acov_t(num_chains);
+  for (size_t i = 0; i < num_chains; ++i) {
+    acov_t(i) = acov(1, i);
+  }
+  double rho_hat_odd = 1 - (w_chain_var - acov_t.mean()) / var_plus;
+  rho_hat_t(1) = rho_hat_odd;  // lag 1
+
+  // compute autocorrelation at lag t for pair (t, t+1)
+  // paired autocorrelation is guaranteed to be positive, monotone and convex
+  size_t t = 1;
+  while (t < draws_per_chain - 4 && (rho_hat_even + rho_hat_odd > 0)
+         && !std::isnan(rho_hat_even + rho_hat_odd)) {
+    for (size_t i = 0; i < num_chains; ++i) {
+      acov_t(i) = acov.col(i)(t + 1);
+    }
+    rho_hat_even = 1 - (w_chain_var - acov_t.mean()) / var_plus;
+    for (size_t i = 0; i < num_chains; ++i) {
+      acov_t(i) = acov.col(i)(t + 2);
+    }
+    rho_hat_odd = 1 - (w_chain_var - acov_t.mean()) / var_plus;
+    if ((rho_hat_even + rho_hat_odd) >= 0) {
+      rho_hat_t(t + 1) = rho_hat_even;
+      rho_hat_t(t + 2) = rho_hat_odd;
+    }
+    // convert initial positive sequence into an initial monotone sequence
+    if (rho_hat_t(t + 1) + rho_hat_t(t + 2) > rho_hat_t(t - 1) + rho_hat_t(t)) {
+      rho_hat_t(t + 1) = (rho_hat_t(t - 1) + rho_hat_t(t)) / 2;
+      rho_hat_t(t + 2) = rho_hat_t(t + 1);
+    }
+    t += 2;
+  }
+
+  auto max_t = t;  // max lag, used for truncation
+  // see discussion p. 8, par "In extreme antithetic cases, "
+  if (rho_hat_even > 0) {
+    rho_hat_t(max_t + 1) = rho_hat_even;
+  }
+
+  double draws_total = num_chains * draws_per_chain;
+  // eqn (13): Geyer's truncation rule, w/ modification
+  double tau_hat = -1 + 2 * rho_hat_t.head(max_t).sum() + rho_hat_t(max_t + 1);
+  // safety check for negative values and with max ess equal to ess*log10(ess)
+  tau_hat = std::max(tau_hat, 1 / std::log10(draws_total));
+  return (draws_total / tau_hat);
+}
+
+}  // namespace analyze
+}  // namespace stan
+
+#endif
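
To make the intended call pattern concrete, here is a minimal, hypothetical usage sketch (not part of the PR). It assumes only the signature shown above, stan::analyze::ess(const Eigen::MatrixXd&), with one chain per column; for independent draws the estimate should come out near the total number of draws.

#include <stan/analyze/mcmc/ess.hpp>  // header added by this PR

#include <iostream>
#include <random>

int main() {
  // Hypothetical data: 1000 draws from each of 4 chains for one parameter,
  // stored one chain per column as the doc comment above requires.
  const int num_draws = 1000;
  const int num_chains = 4;
  std::mt19937 rng(1234);
  std::normal_distribution<double> normal(0.0, 1.0);

  Eigen::MatrixXd chains(num_draws, num_chains);
  for (int j = 0; j < num_chains; ++j) {
    for (int i = 0; i < num_draws; ++i) {
      chains(i, j) = normal(rng);  // independent draws, so ESS should be near 4000
    }
  }

  std::cout << "bulk ESS estimate: " << stan::analyze::ess(chains) << "\n";
  return 0;
}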

src/stan/analyze/mcmc/mcse.hpp

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#ifndef STAN_ANALYZE_MCMC_MCSE_HPP
+#define STAN_ANALYZE_MCMC_MCSE_HPP
+
+#include <stan/analyze/mcmc/check_chains.hpp>
+#include <stan/analyze/mcmc/split_chains.hpp>
+#include <stan/analyze/mcmc/ess.hpp>
+#include <stan/math/prim.hpp>
+#include <cmath>
+#include <limits>
+#include <utility>
+
+namespace stan {
+namespace analyze {
+
+/**
+ * Computes the mean Monte Carlo error estimate for the central 90% interval.
+ * See https://arxiv.org/abs/1903.08008, section 4.4.
+ * Follows implementation in the R posterior package.
+ *
+ * @param chains matrix of draws across all chains
+ * @return mcse
+ */
+inline double mcse_mean(const Eigen::MatrixXd& chains) {
+  const Eigen::Index num_draws = chains.rows();
+  if (chains.rows() < 4 || !is_finite_and_varies(chains))
+    return std::numeric_limits<double>::quiet_NaN();
+
+  double sample_var
+      = (chains.array() - chains.mean()).square().sum() / (chains.size() - 1);
+  return std::sqrt(sample_var / ess(chains));
+}
+
+/**
+ * Computes the standard deviation of the Monte Carlo error estimate
+ * https://arxiv.org/abs/1903.08008, section 4.4.
+ * Follows implementation in the R posterior package:
+ * https://github.com/stan-dev/posterior/blob/98bf52329d68f3307ac4ecaaea659276ee1de8df/R/convergence.R#L478-L496
+ *
+ * @param chains matrix of draws across all chains
+ * @return mcse
+ */
+inline double mcse_sd(const Eigen::MatrixXd& chains) {
+  if (chains.rows() < 4 || !is_finite_and_varies(chains))
+    return std::numeric_limits<double>::quiet_NaN();
+
+  // center the data, take abs value
+  Eigen::MatrixXd draws_ctr = (chains.array() - chains.mean()).abs().matrix();
+
+  // posterior pkg fn `ess_mean` computes on split chains
+  double ess_mean = ess(split_chains(draws_ctr));
+
+  // estimated variance (2nd moment)
+  double Evar = draws_ctr.array().square().mean();
+
+  // variance of variance, adjusted for ESS
+  double fourth_moment = draws_ctr.array().pow(4).mean();
+  double varvar = (fourth_moment - std::pow(Evar, 2)) / ess_mean;
+
+  // variance of standard deviation - use Taylor series approximation
+  double varsd = varvar / Evar / 4.0;
+  return std::sqrt(varsd);
+}
+
+}  // namespace analyze
+}  // namespace stan
+
+#endif
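
The "Taylor series approximation" comment inside mcse_sd is terse, so here is the step it appears to refer to, written out (my reading of the code, not text from the PR). A first-order delta-method expansion of the standard deviation s = sqrt(v) around the estimated second moment gives

\[
s = \sqrt{v}, \qquad
\frac{\partial s}{\partial v} = \frac{1}{2\sqrt{v}}
\quad\Longrightarrow\quad
\operatorname{Var}(s) \approx \left(\frac{1}{2\sqrt{v}}\right)^{2} \operatorname{Var}(v)
  = \frac{\operatorname{Var}(v)}{4v},
\]

which matches `varsd = varvar / Evar / 4.0`, with v playing the role of Evar and Var(v) the role of varvar.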

src/stan/analyze/mcmc/rank_normalization.hpp

Lines changed: 3 additions & 2 deletions
@@ -12,8 +12,9 @@ namespace stan {
 namespace analyze {

 /**
- * Computes normalized average ranks for pooled draws. Normal scores computed
- * using inverse normal transformation and a fractional offset. Based on paper
+ * Computes normalized average ranks for pooled draws. The values across
+ * all draws must be finite and not constant. Normal scores computed using
+ * inverse normal transformation and a fractional offset. Based on paper
  * https://arxiv.org/abs/1903.08008
  *
  * @param chains matrix of draws, one column per chain
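
For reference, the "inverse normal transformation and a fractional offset" mentioned above is described in the cited paper (arXiv:1903.08008) as the Blom-style transform sketched below; the exact constants live in the unchanged body of this file, not in this hunk, so treat this as a sketch of the paper's formula rather than a quote of the code:

\[
z_i = \Phi^{-1}\!\left(\frac{r_i - 3/8}{S + 1/4}\right),
\]

where r_i is the average rank of draw i among the S pooled draws and \Phi^{-1} is the standard normal quantile function.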

src/stan/analyze/mcmc/rhat.hpp

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+#ifndef STAN_ANALYZE_MCMC_RHAT_HPP
+#define STAN_ANALYZE_MCMC_RHAT_HPP
+
+#include <stan/math/prim.hpp>
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include <limits>
+
+namespace stan {
+namespace analyze {
+
+/**
+ * Computes square root of marginal posterior variance of the estimand by the
+ * weighted average of within-chain variance W and between-chain variance B.
+ *
+ * @param chains stores chains in columns
+ * @return square root of ((N-1)/N)W + B/N
+ */
+inline double rhat(const Eigen::MatrixXd& chains) {
+  const Eigen::Index num_chains = chains.cols();
+  const Eigen::Index num_draws = chains.rows();
+
+  Eigen::RowVectorXd within_chain_means = chains.colwise().mean();
+  double across_chain_mean = within_chain_means.mean();
+  double between_variance
+      = num_draws
+        * (within_chain_means.array() - across_chain_mean).square().sum()
+        / (num_chains - 1);
+  double within_variance =
+      // Divide each row by chains and get sum of squares for each chain
+      // (getting a vector back)
+      ((chains.rowwise() - within_chain_means)
+           .array()
+           .square()
+           .colwise()
+           // divide each sum of square by num_draws, sum the sum of squares,
+           // and divide by num chains
+           .sum()
+       / (num_draws - 1.0))
+          .sum()
+      / num_chains;
+
+  return sqrt((between_variance / within_variance + num_draws - 1) / num_draws);
+}
+
+}  // namespace analyze
+}  // namespace stan
+
+#endif
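
Written out, the return expression above is the classic (non-split, non-rank-normalized) potential scale reduction. With N = num_draws, W = within_variance, and B = between_variance as computed in the code,

\[
\widehat{\operatorname{var}}^{+} = \frac{N-1}{N}\,W + \frac{1}{N}\,B,
\qquad
\hat{R} = \sqrt{\frac{\widehat{\operatorname{var}}^{+}}{W}}
        = \sqrt{\frac{B/W + N - 1}{N}},
\]

and the right-hand form is literally the `return` line. The split, rank-normalized variant that the deprecation notes point to (split_rank_normalized_rhat) is not shown in this excerpt.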

src/stan/analyze/mcmc/split_chains.hpp

Lines changed: 9 additions & 7 deletions
@@ -20,16 +20,17 @@ namespace analyze {
 inline Eigen::MatrixXd split_chains(const std::vector<Eigen::MatrixXd>& chains,
                                     const int index) {
   size_t num_chains = chains.size();
-  size_t num_samples = chains[0].rows();
-  size_t half = std::floor(num_samples / 2.0);
+  size_t num_draws = chains[0].rows();
+  size_t half = std::floor(num_draws / 2.0);
+  size_t tail_start = std::floor((num_draws + 1) / 2.0);

   Eigen::MatrixXd split_draws_matrix(half, num_chains * 2);
   int split_i = 0;
   for (std::size_t i = 0; i < num_chains; ++i) {
     Eigen::Map<const Eigen::VectorXd> head_block(chains[i].col(index).data(),
                                                  half);
     Eigen::Map<const Eigen::VectorXd> tail_block(
-        chains[i].col(index).data() + half, half);
+        chains[i].col(index).data() + tail_start, half);

     split_draws_matrix.col(split_i) = head_block;
     split_draws_matrix.col(split_i + 1) = tail_block;
@@ -47,15 +48,16 @@ inline Eigen::MatrixXd split_chains(const std::vector<Eigen::MatrixXd>& chains,
  */
 inline Eigen::MatrixXd split_chains(const Eigen::MatrixXd& samples) {
   size_t num_chains = samples.cols();
-  size_t num_samples = samples.rows();
-  size_t half = std::floor(num_samples / 2.0);
+  size_t num_draws = samples.rows();
+  size_t half = std::floor(num_draws / 2.0);
+  size_t tail_start = std::floor((num_draws + 1) / 2.0);

   Eigen::MatrixXd split_draws_matrix(half, num_chains * 2);
   int split_i = 0;
   for (std::size_t i = 0; i < num_chains; ++i) {
     Eigen::Map<const Eigen::VectorXd> head_block(samples.col(i).data(), half);
-    Eigen::Map<const Eigen::VectorXd> tail_block(samples.col(i).data() + half,
-                                                 half);
+    Eigen::Map<const Eigen::VectorXd> tail_block(
+        samples.col(i).data() + tail_start, half);

     split_draws_matrix.col(split_i) = head_block;
     split_draws_matrix.col(split_i + 1) = tail_block;
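
A small, hypothetical example of what the new tail_start buys for an odd number of draws (my illustration, not code from the PR): with 7 draws, half = 3 and tail_start = 4, so the head half is draws 0-2, the tail half is draws 4-6, and the middle draw is skipped, matching the "(N+1)/2th draw is ignored" wording in the Rhat docs above. The old code started the tail at index 3 and so dropped the last draw instead.

#include <stan/analyze/mcmc/split_chains.hpp>

#include <iostream>

int main() {
  // One chain of seven draws whose values equal their index.
  Eigen::MatrixXd samples(7, 1);
  samples << 0, 1, 2, 3, 4, 5, 6;

  // Expected 3 x 2 result:
  //   head column: 0, 1, 2
  //   tail column: 4, 5, 6   (draw 3, the middle one, is dropped)
  Eigen::MatrixXd split = stan::analyze::split_chains(samples);
  std::cout << split << "\n";
  return 0;
}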
