Skip to content

Commit 4dc20ff

Browse files
authored
Merge pull request #3291 from stan-dev/fix/init-err-msgs
update error message for different init types
2 parents 4ff44b8 + 39b8333 commit 4dc20ff

File tree

5 files changed

+131
-18
lines changed

5 files changed

+131
-18
lines changed

src/stan/services/util/initialize.hpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,19 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
102102
model.transform_inits(context, disc_vector, unconstrained, &msg);
103103
}
104104
} catch (std::domain_error& e) {
105-
if (msg.str().length() > 0)
105+
if (msg.str().length() > 0) {
106106
logger.info(msg);
107+
}
107108
logger.warn("Rejecting initial value:");
108109
logger.warn(
109110
" Error evaluating the log probability"
110111
" at the initial value.");
111112
logger.warn(e.what());
112113
continue;
113114
} catch (std::exception& e) {
114-
if (msg.str().length() > 0)
115+
if (msg.str().length() > 0) {
115116
logger.info(msg);
117+
}
116118
logger.error(
117119
"Unrecoverable error evaluating the log probability"
118120
" at the initial value.");
@@ -127,8 +129,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
127129
// the parameters.
128130
log_prob = model.template log_prob<false, Jacobian>(unconstrained,
129131
disc_vector, &msg);
130-
if (msg.str().length() > 0)
132+
if (msg.str().length() > 0) {
131133
logger.info(msg);
134+
}
132135
} catch (std::domain_error& e) {
133136
if (msg.str().length() > 0)
134137
logger.info(msg);
@@ -139,8 +142,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
139142
logger.warn(e.what());
140143
continue;
141144
} catch (std::exception& e) {
142-
if (msg.str().length() > 0)
145+
if (msg.str().length() > 0) {
143146
logger.info(msg);
147+
}
144148
logger.error(
145149
"Unrecoverable error evaluating the log probability"
146150
" at the initial value.");
@@ -165,8 +169,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
165169
log_prob = stan::model::log_prob_grad<true, Jacobian>(
166170
model, unconstrained, disc_vector, gradient, &log_prob_msg);
167171
} catch (const std::exception& e) {
168-
if (log_prob_msg.str().length() > 0)
172+
if (log_prob_msg.str().length() > 0) {
169173
logger.info(log_prob_msg);
174+
}
170175
logger.error(e.what());
171176
throw;
172177
}
@@ -210,8 +215,36 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
210215
return unconstrained;
211216
}
212217
}
213-
214-
if (!is_initialized_with_zero) {
218+
if (is_fully_initialized) {
219+
logger.info("");
220+
logger.error("User-specified initialization failed.");
221+
logger.error(
222+
" Try specifying new initial values,"
223+
" using partially specialized initialization,"
224+
" reducing the range of constrained values,"
225+
" or reparameterizing the model.");
226+
} else if (any_initialized) {
227+
logger.info("");
228+
std::stringstream msg;
229+
msg << "Partial user-specified initialization failed. "
230+
"Initialization of non user specified parameters "
231+
"between (-"
232+
<< init_radius << ", " << init_radius << ") failed after"
233+
<< " " << MAX_INIT_TRIES << " attempts. ";
234+
logger.error(msg);
235+
logger.error(
236+
" Try specifying full initial values,"
237+
" reducing the range of constrained values,"
238+
" or reparameterizing the model.");
239+
} else if (is_initialized_with_zero) {
240+
logger.info("");
241+
logger.error("Initial values of 0 failed to initialize.");
242+
logger.error(
243+
" Try specifying new initial values,"
244+
" using partially specialized initialization,"
245+
" reducing the range of constrained values,"
246+
" or reparameterizing the model.");
247+
} else {
215248
logger.info("");
216249
std::stringstream msg;
217250
msg << "Initialization between (-" << init_radius << ", " << init_radius
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
parameters {
2+
array[2] real<lower=-10, upper=10> y;
3+
}
4+
model {
5+
reject("");
6+
}

src/test/unit/services/instrumented_callbacks.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,23 @@ class instrumented_logger : public stan::callbacks::logger {
282282
return count;
283283
}
284284

285+
public:
286+
std::vector<std::string> return_all_logs() {
287+
std::vector<std::string> all_logs;
288+
all_logs.reserve(debug_.size() + info_.size() + warn_.size() + error_.size()
289+
+ fatal_.size() + 5);
290+
all_logs.emplace_back("DEBUG");
291+
all_logs.insert(all_logs.end(), debug_.begin(), debug_.end());
292+
all_logs.emplace_back("INFO");
293+
all_logs.insert(all_logs.end(), info_.begin(), info_.end());
294+
all_logs.emplace_back("WARN");
295+
all_logs.insert(all_logs.end(), warn_.begin(), warn_.end());
296+
all_logs.emplace_back("ERROR");
297+
all_logs.insert(all_logs.end(), error_.begin(), error_.end());
298+
all_logs.emplace_back("FATAL");
299+
all_logs.insert(all_logs.end(), fatal_.begin(), fatal_.end());
300+
return all_logs;
301+
}
285302
std::vector<std::string> debug_;
286303
std::vector<std::string> info_;
287304
std::vector<std::string> warn_;
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#include <stan/services/util/initialize.hpp>
2+
#include <stan/services/util/create_rng.hpp>
3+
#include <stan/io/empty_var_context.hpp>
4+
#include <stan/io/array_var_context.hpp>
5+
#include <stan/services/util/create_rng.hpp>
6+
#include <stan/callbacks/stream_writer.hpp>
7+
#include <stan/callbacks/stream_logger.hpp>
8+
#include <test/test-models/good/services/test_fail.hpp>
9+
#include <test/unit/util.hpp>
10+
#include <test/unit/services/instrumented_callbacks.hpp>
11+
#include <gtest/gtest.h>
12+
#include <sstream>
13+
14+
class ServicesUtilInitialize : public testing::Test {
15+
public:
16+
ServicesUtilInitialize()
17+
: model(empty_context, 12345, &model_ss),
18+
message(message_ss),
19+
rng(stan::services::util::create_rng(0, 1)) {}
20+
21+
stan_model model;
22+
stan::io::empty_var_context empty_context;
23+
std::stringstream model_ss;
24+
std::stringstream message_ss;
25+
stan::callbacks::stream_writer message;
26+
stan::test::unit::instrumented_logger logger;
27+
stan::test::unit::instrumented_writer init;
28+
stan::rng_t rng;
29+
};
30+
31+
TEST_F(ServicesUtilInitialize, model_throws__full_init) {
32+
std::vector<std::string> names_r;
33+
std::vector<double> values_r;
34+
std::vector<std::vector<size_t> > dim_r;
35+
names_r.push_back("y");
36+
values_r.push_back(6.35149); // 1.5 unconstrained: -10 + 20 * inv.logit(1.5)
37+
values_r.push_back(-2.449187); // -0.5 unconstrained
38+
std::vector<size_t> d;
39+
d.push_back(2);
40+
dim_r.push_back(d);
41+
stan::io::array_var_context init_context(names_r, values_r, dim_r);
42+
43+
double init_radius = 2;
44+
bool print_timing = false;
45+
EXPECT_THROW(
46+
stan::services::util::initialize(model, init_context, rng, init_radius,
47+
print_timing, logger, init),
48+
std::domain_error);
49+
/* Uncomment to print all logs
50+
auto logs = logger.return_all_logs();
51+
for (auto&& m : logs) {
52+
std::cout << m << std::endl;
53+
}
54+
*/
55+
EXPECT_EQ(6, logger.call_count());
56+
EXPECT_EQ(3, logger.call_count_warn());
57+
EXPECT_EQ(0, logger.find_warn("throwing within log_prob"));
58+
}

src/test/unit/services/util/initialize_test.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#include <stan/services/util/initialize.hpp>
22
#include <stan/services/util/create_rng.hpp>
3-
#include <gtest/gtest.h>
4-
#include <test/unit/util.hpp>
5-
#include <stan/callbacks/stream_writer.hpp>
6-
#include <stan/callbacks/stream_logger.hpp>
7-
#include <sstream>
8-
#include <test/test-models/good/services/test_lp.hpp>
93
#include <stan/io/empty_var_context.hpp>
104
#include <stan/io/array_var_context.hpp>
115
#include <stan/services/util/create_rng.hpp>
6+
#include <stan/callbacks/stream_writer.hpp>
7+
#include <stan/callbacks/stream_logger.hpp>
8+
#include <test/test-models/good/services/test_lp.hpp>
9+
#include <test/unit/util.hpp>
1210
#include <test/unit/services/instrumented_callbacks.hpp>
11+
#include <gtest/gtest.h>
12+
#include <sstream>
1313

1414
class ServicesUtilInitialize : public testing::Test {
1515
public:
@@ -28,7 +28,7 @@ class ServicesUtilInitialize : public testing::Test {
2828
stan::rng_t rng;
2929
};
3030

31-
TEST_F(ServicesUtilInitialize, radius_zero__print_false) {
31+
TEST_F(ServicesUtilInitialize, radius_zero_print_false) {
3232
std::vector<double> params;
3333

3434
double init_radius = 0;
@@ -250,7 +250,7 @@ class mock_throwing_model : public stan::model::prob_grad {
250250

251251
} // namespace test
252252

253-
TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
253+
TEST_F(ServicesUtilInitialize, model_throws_radius_zero) {
254254
test::mock_throwing_model throwing_model;
255255

256256
double init_radius = 0;
@@ -259,8 +259,7 @@ TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
259259
stan::services::util::initialize(throwing_model, empty_context, rng,
260260
init_radius, print_timing, logger, init),
261261
std::domain_error);
262-
263-
EXPECT_EQ(3, logger.call_count());
262+
EXPECT_EQ(6, logger.call_count());
264263
EXPECT_EQ(3, logger.call_count_warn());
265264
EXPECT_EQ(1, logger.find_warn("throwing within log_prob"));
266265
}
@@ -533,7 +532,7 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__radius_zero) {
533532
init_radius, print_timing, logger, init),
534533
std::domain_error);
535534

536-
EXPECT_EQ(3, logger.call_count());
535+
EXPECT_EQ(6, logger.call_count());
537536
EXPECT_EQ(3, logger.call_count_warn());
538537
EXPECT_EQ(1, logger.find_warn("throwing within write_array"));
539538
}

0 commit comments

Comments
 (0)