1+ #include < cmdstan/return_codes.hpp>
12#include < cmdstan/stansummary_helper.hpp>
2- #include < stan/mcmc/chains .hpp>
3+ #include < stan/mcmc/chainset .hpp>
34#include < algorithm>
45#include < fstream>
56#include < iomanip>
67#include < ios>
78#include < iostream>
89
9- double RHAT_MAX = 1.05 ;
10+ using cmdstan::return_codes;
11+
12+ double RHAT_MAX = 1.01499 ; // round to 1.01
1013
1114void diagnose_usage () {
1215 std::cout << " USAGE: diagnose <filename 1> [<filename 2> ... <filename N>]"
@@ -26,7 +29,7 @@ void diagnose_usage() {
2629int main (int argc, const char *argv[]) {
2730 if (argc == 1 ) {
2831 diagnose_usage ();
29- return 0 ;
32+ return return_codes::OK ;
3033 }
3134
3235 // Parse any arguments specifying filenames
@@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) {
4548
4649 if (!filenames.size ()) {
4750 std::cout << " No valid input files, exiting." << std::endl;
48- return 0 ;
51+ return return_codes::NOT_OK ;
4952 }
5053
5154 std::cout << std::fixed << std::setprecision (2 );
5255
53- // Parse specified files
54- std::cout << " Processing csv files: " << filenames[0 ];
55- ifstream.open (filenames[0 ].c_str ());
56-
57- stan::io::stan_csv stan_csv
58- = stan::io::stan_csv_reader::parse (ifstream, &std::cout);
59- stan::mcmc::chains<> chains (stan_csv);
60- ifstream.close ();
61-
62- if (filenames.size () > 1 )
63- std::cout << " , " ;
64- else
65- std::cout << std::endl << std::endl;
66-
67- for (std::vector<std::string>::size_type chain = 1 ; chain < filenames.size ();
68- ++chain) {
69- std::cout << filenames[chain];
70- ifstream.open (filenames[chain].c_str ());
71- stan_csv = stan::io::stan_csv_reader::parse (ifstream, &std::cout);
72- chains.add (stan_csv);
73- ifstream.close ();
74- if (chain < filenames.size () - 1 )
75- std::cout << " , " ;
76- else
77- std::cout << std::endl << std::endl;
56+ std::vector<stan::io::stan_csv> csv_parsed;
57+ for (int i = 0 ; i < filenames.size (); ++i) {
58+ std::ifstream infile;
59+ std::stringstream out;
60+ stan::io::stan_csv sample;
61+ infile.open (filenames[i].c_str ());
62+ try {
63+ sample = stan::io::stan_csv_reader::parse (infile, &out);
64+ // csv_reader warnings are errors - fail fast.
65+ if (!out.str ().empty ()) {
66+ throw std::invalid_argument (out.str ());
67+ }
68+ csv_parsed.push_back (sample);
69+ } catch (const std::invalid_argument &e) {
70+ std::cout << " Cannot parse input csv file: " << filenames[i] << e.what ()
71+ << " ." << std::endl;
72+ return return_codes::NOT_OK;
73+ }
7874 }
79-
75+ stan::mcmc::chainset chains (csv_parsed);
76+ stan::io::stan_csv_metadata metadata = csv_parsed[0 ].metadata ;
77+ std::vector<std::string> param_names = csv_parsed[0 ].header ;
78+ size_t num_params = param_names.size ();
8079 int num_samples = chains.num_samples ();
8180 std::vector<std::string> bad_n_eff_names;
8281 std::vector<std::string> bad_rhat_names;
8382 bool has_errors = false ;
8483
85- for (int i = 0 ; i < chains. num_params () ; ++i) {
86- if (chains. param_name (i) == std::string (" treedepth__" )) {
84+ for (int i = 0 ; i < num_params; ++i) {
85+ if (param_names[i] == std::string (" treedepth__" )) {
8786 std::cout << " Checking sampler transitions treedepth." << std::endl;
88- int max_limit = stan_csv. metadata .max_depth ;
87+ int max_limit = metadata.max_depth ;
8988 long n_max = 0 ;
90- Eigen::VectorXd t_samples = chains.samples (i);
89+ Eigen::MatrixXd draws = chains.samples (i);
90+ Eigen::VectorXd t_samples
91+ = Eigen::Map<Eigen::VectorXd>(draws.data (), draws.size ());
9192 for (long n = 0 ; n < t_samples.size (); ++n) {
9293 if (t_samples (n) >= max_limit) {
9394 ++n_max;
@@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) {
109110 std::cout << " Treedepth satisfactory for all transitions." << std::endl
110111 << std::endl;
111112 }
112- } else if (chains. param_name (i) == std::string (" divergent__" )) {
113+ } else if (param_names[i] == std::string (" divergent__" )) {
113114 std::cout << " Checking sampler transitions for divergences." << std::endl;
114115 int n_divergent = chains.samples (i).sum ();
115116 if (n_divergent > 0 ) {
@@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) {
129130 std::cout << " No divergent transitions found." << std::endl
130131 << std::endl;
131132 }
132- } else if (chains. param_name (i) == std::string (" energy__" )) {
133+ } else if (param_names[i] == std::string (" energy__" )) {
133134 std::cout << " Checking E-BFMI - sampler transitions HMC potential energy."
134135 << std::endl;
135- Eigen::VectorXd e_samples = chains.samples (i);
136+ Eigen::MatrixXd draws = chains.samples (i);
137+ Eigen::VectorXd e_samples
138+ = Eigen::Map<Eigen::VectorXd>(draws.data (), draws.size ());
136139 double delta_e_sq_mean = 0 ;
137- double e_mean = 0 ;
138- double e_var = 0 ;
139- e_mean += e_samples (0 );
140- e_var += e_samples (0 ) * (e_samples (0 ) - e_mean);
140+ double e_mean = chains.mean (i);
141+ double e_var = chains.variance (i);
141142 for (long n = 1 ; n < e_samples.size (); ++n) {
142143 double e = e_samples (n);
143144 double delta_e_sq = (e - e_samples (n - 1 )) * (e - e_samples (n - 1 ));
144145 double d = delta_e_sq - delta_e_sq_mean;
145146 delta_e_sq_mean += d / n;
146147 d = e - e_mean;
147- e_mean += d / (n + 1 );
148- e_var += d * (e - e_mean);
149148 }
150-
151- e_var /= static_cast <double >(e_samples.size () - 1 );
152149 double e_bfmi = delta_e_sq_mean / e_var;
153150 double e_bfmi_threshold = 0.3 ;
154151 if (e_bfmi < e_bfmi_threshold) {
@@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) {
163160 } else {
164161 std::cout << " E-BFMI satisfactory." << std::endl << std::endl;
165162 }
166- } else if (chains.param_name (i).find (" __" ) == std::string::npos) {
167- double n_eff = chains.effective_sample_size (i);
163+ } else if (param_names[i].find (" __" ) == std::string::npos) {
164+ auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess (i);
165+ double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail;
168166 if (n_eff / num_samples < 0.001 )
169- bad_n_eff_names.push_back (chains. param_name (i) );
167+ bad_n_eff_names.push_back (param_names[i] );
170168
171- double split_rhat = chains.split_potential_scale_reduction (i);
169+ auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat (i);
170+ double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
172171 if (split_rhat > RHAT_MAX)
173- bad_rhat_names.push_back (chains. param_name (i) );
172+ bad_rhat_names.push_back (param_names[i] );
174173 }
175174 }
176175 if (bad_n_eff_names.size () > 0 ) {
@@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) {
187186 << " may be substantially lower than quoted." << std::endl
188187 << std::endl;
189188 } else {
190- std::cout << " Effective sample size satisfactory." << std::endl
189+ std::cout << " Rank-normalized split effective sample size satisfactory "
190+ << " for all parameters." << std::endl
191191 << std::endl;
192192 }
193193
194194 if (bad_rhat_names.size () > 0 ) {
195195 has_errors = true ;
196- std::cout << " The following parameters had split R-hat greater than "
196+ std::cout << " The following parameters had rank-normalized split R-hat "
197+ " greater than "
197198 << RHAT_MAX << " :" << std::endl;
198199 std::cout << " " ;
199200 for (size_t n = 0 ; n < bad_rhat_names.size () - 1 ; ++n)
@@ -207,13 +208,14 @@ int main(int argc, const char *argv[]) {
207208 << " effective parameterization." << std::endl
208209 << std::endl;
209210 } else {
210- std::cout << " Split R-hat values satisfactory all parameters." << std::endl
211+ std::cout << " Rank-normalized split R-hat values satisfactory "
212+ << " for all parameters." << std::endl
211213 << std::endl;
212214 }
213215 if (!has_errors)
214216 std::cout << " Processing complete, no problems detected." << std::endl;
215217 else
216218 std::cout << " Processing complete." << std::endl;
217219
218- return 0 ;
220+ return return_codes::OK ;
219221}
0 commit comments