Skip to content

Commit 2dde61c

Browse files
authored
Merge pull request #1301 from stan-dev/stansummary-new-stats
Update stansummary to use split, rank-normalized Rhat, ESS
2 parents 82c51d4 + a23af88 commit 2dde61c

File tree

11 files changed

+189
-205
lines changed

11 files changed

+189
-205
lines changed

src/cmdstan/diagnose.cpp

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
#include <cmdstan/return_codes.hpp>
12
#include <cmdstan/stansummary_helper.hpp>
2-
#include <stan/mcmc/chains.hpp>
3+
#include <stan/mcmc/chainset.hpp>
34
#include <algorithm>
45
#include <fstream>
56
#include <iomanip>
67
#include <ios>
78
#include <iostream>
89

9-
double RHAT_MAX = 1.05;
10+
using cmdstan::return_codes;
11+
12+
double RHAT_MAX = 1.01499; // round to 1.01
1013

1114
void diagnose_usage() {
1215
std::cout << "USAGE: diagnose <filename 1> [<filename 2> ... <filename N>]"
@@ -26,7 +29,7 @@ void diagnose_usage() {
2629
int main(int argc, const char *argv[]) {
2730
if (argc == 1) {
2831
diagnose_usage();
29-
return 0;
32+
return return_codes::OK;
3033
}
3134

3235
// Parse any arguments specifying filenames
@@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) {
4548

4649
if (!filenames.size()) {
4750
std::cout << "No valid input files, exiting." << std::endl;
48-
return 0;
51+
return return_codes::NOT_OK;
4952
}
5053

5154
std::cout << std::fixed << std::setprecision(2);
5255

53-
// Parse specified files
54-
std::cout << "Processing csv files: " << filenames[0];
55-
ifstream.open(filenames[0].c_str());
56-
57-
stan::io::stan_csv stan_csv
58-
= stan::io::stan_csv_reader::parse(ifstream, &std::cout);
59-
stan::mcmc::chains<> chains(stan_csv);
60-
ifstream.close();
61-
62-
if (filenames.size() > 1)
63-
std::cout << ", ";
64-
else
65-
std::cout << std::endl << std::endl;
66-
67-
for (std::vector<std::string>::size_type chain = 1; chain < filenames.size();
68-
++chain) {
69-
std::cout << filenames[chain];
70-
ifstream.open(filenames[chain].c_str());
71-
stan_csv = stan::io::stan_csv_reader::parse(ifstream, &std::cout);
72-
chains.add(stan_csv);
73-
ifstream.close();
74-
if (chain < filenames.size() - 1)
75-
std::cout << ", ";
76-
else
77-
std::cout << std::endl << std::endl;
56+
std::vector<stan::io::stan_csv> csv_parsed;
57+
for (int i = 0; i < filenames.size(); ++i) {
58+
std::ifstream infile;
59+
std::stringstream out;
60+
stan::io::stan_csv sample;
61+
infile.open(filenames[i].c_str());
62+
try {
63+
sample = stan::io::stan_csv_reader::parse(infile, &out);
64+
// csv_reader warnings are errors - fail fast.
65+
if (!out.str().empty()) {
66+
throw std::invalid_argument(out.str());
67+
}
68+
csv_parsed.push_back(sample);
69+
} catch (const std::invalid_argument &e) {
70+
std::cout << "Cannot parse input csv file: " << filenames[i] << e.what()
71+
<< "." << std::endl;
72+
return return_codes::NOT_OK;
73+
}
7874
}
79-
75+
stan::mcmc::chainset chains(csv_parsed);
76+
stan::io::stan_csv_metadata metadata = csv_parsed[0].metadata;
77+
std::vector<std::string> param_names = csv_parsed[0].header;
78+
size_t num_params = param_names.size();
8079
int num_samples = chains.num_samples();
8180
std::vector<std::string> bad_n_eff_names;
8281
std::vector<std::string> bad_rhat_names;
8382
bool has_errors = false;
8483

85-
for (int i = 0; i < chains.num_params(); ++i) {
86-
if (chains.param_name(i) == std::string("treedepth__")) {
84+
for (int i = 0; i < num_params; ++i) {
85+
if (param_names[i] == std::string("treedepth__")) {
8786
std::cout << "Checking sampler transitions treedepth." << std::endl;
88-
int max_limit = stan_csv.metadata.max_depth;
87+
int max_limit = metadata.max_depth;
8988
long n_max = 0;
90-
Eigen::VectorXd t_samples = chains.samples(i);
89+
Eigen::MatrixXd draws = chains.samples(i);
90+
Eigen::VectorXd t_samples
91+
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
9192
for (long n = 0; n < t_samples.size(); ++n) {
9293
if (t_samples(n) >= max_limit) {
9394
++n_max;
@@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) {
109110
std::cout << "Treedepth satisfactory for all transitions." << std::endl
110111
<< std::endl;
111112
}
112-
} else if (chains.param_name(i) == std::string("divergent__")) {
113+
} else if (param_names[i] == std::string("divergent__")) {
113114
std::cout << "Checking sampler transitions for divergences." << std::endl;
114115
int n_divergent = chains.samples(i).sum();
115116
if (n_divergent > 0) {
@@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) {
129130
std::cout << "No divergent transitions found." << std::endl
130131
<< std::endl;
131132
}
132-
} else if (chains.param_name(i) == std::string("energy__")) {
133+
} else if (param_names[i] == std::string("energy__")) {
133134
std::cout << "Checking E-BFMI - sampler transitions HMC potential energy."
134135
<< std::endl;
135-
Eigen::VectorXd e_samples = chains.samples(i);
136+
Eigen::MatrixXd draws = chains.samples(i);
137+
Eigen::VectorXd e_samples
138+
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
136139
double delta_e_sq_mean = 0;
137-
double e_mean = 0;
138-
double e_var = 0;
139-
e_mean += e_samples(0);
140-
e_var += e_samples(0) * (e_samples(0) - e_mean);
140+
double e_mean = chains.mean(i);
141+
double e_var = chains.variance(i);
141142
for (long n = 1; n < e_samples.size(); ++n) {
142143
double e = e_samples(n);
143144
double delta_e_sq = (e - e_samples(n - 1)) * (e - e_samples(n - 1));
144145
double d = delta_e_sq - delta_e_sq_mean;
145146
delta_e_sq_mean += d / n;
146147
d = e - e_mean;
147-
e_mean += d / (n + 1);
148-
e_var += d * (e - e_mean);
149148
}
150-
151-
e_var /= static_cast<double>(e_samples.size() - 1);
152149
double e_bfmi = delta_e_sq_mean / e_var;
153150
double e_bfmi_threshold = 0.3;
154151
if (e_bfmi < e_bfmi_threshold) {
@@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) {
163160
} else {
164161
std::cout << "E-BFMI satisfactory." << std::endl << std::endl;
165162
}
166-
} else if (chains.param_name(i).find("__") == std::string::npos) {
167-
double n_eff = chains.effective_sample_size(i);
163+
} else if (param_names[i].find("__") == std::string::npos) {
164+
auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i);
165+
double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail;
168166
if (n_eff / num_samples < 0.001)
169-
bad_n_eff_names.push_back(chains.param_name(i));
167+
bad_n_eff_names.push_back(param_names[i]);
170168

171-
double split_rhat = chains.split_potential_scale_reduction(i);
169+
auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i);
170+
double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
172171
if (split_rhat > RHAT_MAX)
173-
bad_rhat_names.push_back(chains.param_name(i));
172+
bad_rhat_names.push_back(param_names[i]);
174173
}
175174
}
176175
if (bad_n_eff_names.size() > 0) {
@@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) {
187186
<< " may be substantially lower than quoted." << std::endl
188187
<< std::endl;
189188
} else {
190-
std::cout << "Effective sample size satisfactory." << std::endl
189+
std::cout << "Rank-normalized split effective sample size satisfactory "
190+
<< "for all parameters." << std::endl
191191
<< std::endl;
192192
}
193193

194194
if (bad_rhat_names.size() > 0) {
195195
has_errors = true;
196-
std::cout << "The following parameters had split R-hat greater than "
196+
std::cout << "The following parameters had rank-normalized split R-hat "
197+
"greater than "
197198
<< RHAT_MAX << ":" << std::endl;
198199
std::cout << " ";
199200
for (size_t n = 0; n < bad_rhat_names.size() - 1; ++n)
@@ -207,13 +208,14 @@ int main(int argc, const char *argv[]) {
207208
<< " effective parameterization." << std::endl
208209
<< std::endl;
209210
} else {
210-
std::cout << "Split R-hat values satisfactory all parameters." << std::endl
211+
std::cout << "Rank-normalized split R-hat values satisfactory "
212+
<< "for all parameters." << std::endl
211213
<< std::endl;
212214
}
213215
if (!has_errors)
214216
std::cout << "Processing complete, no problems detected." << std::endl;
215217
else
216218
std::cout << "Processing complete." << std::endl;
217219

218-
return 0;
220+
return return_codes::OK;
219221
}

src/cmdstan/stansummary.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <cmdstan/return_codes.hpp>
22
#include <cmdstan/stansummary_helper.hpp>
3-
#include <stan/mcmc/chains.hpp>
43
#include <stan/io/ends_with.hpp>
54
#include <algorithm>
65
#include <fstream>
@@ -34,7 +33,7 @@ Example: stansummary model_chain_1.csv model_chain_2.csv
3433
-c, --csv_filename [file] Write statistics to a csv file.
3534
-h, --help Produce help message, then exit.
3635
-p, --percentiles [values] Percentiles to report as ordered set of
37-
comma-separated numbers from (0.1,99.9), inclusive.
36+
comma-separated numbers from (0.0,100.0), inclusive.
3837
Default is 5,50,95.
3938
-s, --sig_figs [n] Significant figures reported. Default is 2.
4039
Must be an integer from (1, 18), inclusive.
@@ -140,8 +139,8 @@ Example: stansummary model_chain_1.csv model_chain_2.csv
140139

141140
// check for stan csv file parse errors written to output stream
142141
std::stringstream cout_ss;
143-
stan::mcmc::chains<> chains = parse_csv_files(
144-
filenames, metadata, warmup_times, sampling_times, thin, &std::cout);
142+
auto chains = parse_csv_files(filenames, metadata, warmup_times,
143+
sampling_times, thin, &std::cout);
145144

146145
// Get column headers for sampler, model params
147146
size_t max_name_length = 0;

0 commit comments

Comments
 (0)