Skip to content

Commit 2f8ae28

Browse files
authored
Merge pull request #3303 from stan-dev/bugfix/3071-fix-unit-test
Fix and add unit tests for issue 3071 - adaptive sampler runs on models w/ zero params.
2 parents cd8b2e0 + e546442 commit 2f8ae28

File tree

3 files changed

+193
-2
lines changed

3 files changed

+193
-2
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
transformed data {
2+
int N = 2;
3+
}
4+
generated quantities {
5+
real theta = beta_rng(1, 1);
6+
real eta = beta_rng(10, 10);
7+
}
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#include <stan/services/sample/hmc_nuts_dense_e.hpp>
2+
#include <gtest/gtest.h>
3+
#include <stan/io/empty_var_context.hpp>
4+
#include <test/test-models/good/services/zero_params.hpp>
5+
#include <test/unit/services/instrumented_callbacks.hpp>
6+
#include <iostream>
7+
8+
auto&& blah = stan::math::init_threadpool_tbb();
9+
10+
static constexpr size_t num_chains = 4;
11+
class ServicesSampleHMCNutsDenseENoParams : public testing::Test {
12+
public:
13+
ServicesSampleHMCNutsDenseENoParams() : model(data_context, 0, &model_log) {
14+
for (int i = 0; i < num_chains; ++i) {
15+
init.push_back(stan::test::unit::instrumented_writer{});
16+
parameter.push_back(stan::test::unit::instrumented_writer{});
17+
diagnostic.push_back(stan::test::unit::instrumented_writer{});
18+
context.push_back(std::make_shared<stan::io::empty_var_context>());
19+
}
20+
}
21+
stan::io::empty_var_context data_context;
22+
std::stringstream model_log;
23+
stan::test::unit::instrumented_logger logger;
24+
std::vector<stan::test::unit::instrumented_writer> init;
25+
std::vector<stan::test::unit::instrumented_writer> parameter;
26+
std::vector<stan::test::unit::instrumented_writer> diagnostic;
27+
std::vector<std::shared_ptr<stan::io::empty_var_context>> context;
28+
stan_model model;
29+
};
30+
31+
TEST_F(ServicesSampleHMCNutsDenseENoParams, call_count) {
32+
unsigned int random_seed = 0;
33+
unsigned int chain = 1;
34+
double init_radius = 0;
35+
int num_warmup = 200;
36+
int num_samples = 400;
37+
int num_thin = 5;
38+
bool save_warmup = true;
39+
int refresh = 0;
40+
double stepsize = 0.1;
41+
double stepsize_jitter = 0;
42+
int max_depth = 8;
43+
double delta = .1;
44+
double gamma = .1;
45+
double kappa = .1;
46+
double t0 = .1;
47+
unsigned int init_buffer = 50;
48+
unsigned int term_buffer = 50;
49+
unsigned int window = 100;
50+
stan::test::unit::instrumented_interrupt interrupt;
51+
EXPECT_EQ(interrupt.call_count(), 0);
52+
53+
int return_code = stan::services::sample::hmc_nuts_dense_e(
54+
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
55+
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
56+
max_depth, interrupt, logger, init, parameter, diagnostic);
57+
58+
EXPECT_EQ(0, return_code);
59+
60+
int num_output_lines = (num_warmup + num_samples) / num_thin;
61+
EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count());
62+
for (int i = 0; i < num_chains; ++i) {
63+
EXPECT_EQ(1, parameter[i].call_count("vector_string"));
64+
EXPECT_EQ(num_output_lines, parameter[i].call_count("vector_double"));
65+
EXPECT_EQ(1, diagnostic[i].call_count("vector_string"));
66+
EXPECT_EQ(num_output_lines, diagnostic[i].call_count("vector_double"));
67+
}
68+
}
69+
70+
TEST_F(ServicesSampleHMCNutsDenseENoParams, parameter_checks) {
71+
unsigned int random_seed = 0;
72+
unsigned int chain = 1;
73+
double init_radius = 0;
74+
int num_warmup = 200;
75+
int num_samples = 400;
76+
int num_thin = 5;
77+
bool save_warmup = true;
78+
int refresh = 0;
79+
double stepsize = 0.1;
80+
double stepsize_jitter = 0;
81+
int max_depth = 8;
82+
double delta = .1;
83+
double gamma = .1;
84+
double kappa = .1;
85+
double t0 = .1;
86+
unsigned int init_buffer = 50;
87+
unsigned int term_buffer = 50;
88+
unsigned int window = 100;
89+
stan::test::unit::instrumented_interrupt interrupt;
90+
EXPECT_EQ(interrupt.call_count(), 0);
91+
92+
int return_code = stan::services::sample::hmc_nuts_dense_e(
93+
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
94+
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
95+
max_depth, interrupt, logger, init, parameter, diagnostic);
96+
97+
for (size_t i = 0; i < num_chains; ++i) {
98+
std::vector<std::vector<std::string>> parameter_names;
99+
parameter_names = parameter[i].vector_string_values();
100+
std::vector<std::vector<double>> parameter_values;
101+
parameter_values = parameter[i].vector_double_values();
102+
std::vector<std::string> strings;
103+
strings = parameter[i].string_values();
104+
std::vector<std::vector<std::string>> diagnostic_names;
105+
diagnostic_names = diagnostic[i].vector_string_values();
106+
std::vector<std::vector<double>> diagnostic_values;
107+
diagnostic_values = diagnostic[i].vector_double_values();
108+
109+
// Expect message at end of warmup
110+
EXPECT_EQ("Adaptation terminated", strings[0]);
111+
112+
// Expectations of sampler and model variables names.
113+
ASSERT_EQ(9, parameter_names[0].size());
114+
EXPECT_EQ("lp__", parameter_names[0][0]);
115+
EXPECT_EQ("accept_stat__", parameter_names[0][1]);
116+
EXPECT_EQ("stepsize__", parameter_names[0][2]);
117+
EXPECT_EQ("treedepth__", parameter_names[0][3]);
118+
EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
119+
EXPECT_EQ("divergent__", parameter_names[0][5]);
120+
EXPECT_EQ("energy__", parameter_names[0][6]);
121+
EXPECT_EQ("theta", parameter_names[0][7]);
122+
EXPECT_EQ("eta", parameter_names[0][8]);
123+
124+
// Expect one name per parameter value.
125+
EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
126+
EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());
127+
128+
EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size());
129+
130+
// Expect one call to set parameter names, and one set of output per
131+
// iteration.
132+
EXPECT_EQ("lp__", diagnostic_names[0][0]);
133+
EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
134+
}
135+
EXPECT_EQ(return_code, 0);
136+
}
137+
138+
TEST_F(ServicesSampleHMCNutsDenseENoParams, output_regression) {
139+
unsigned int random_seed = 0;
140+
unsigned int chain = 1;
141+
double init_radius = 0;
142+
int num_warmup = 200;
143+
int num_samples = 400;
144+
int num_thin = 5;
145+
bool save_warmup = true;
146+
int refresh = 0;
147+
double stepsize = 0.1;
148+
double stepsize_jitter = 0;
149+
int max_depth = 8;
150+
double delta = .1;
151+
double gamma = .1;
152+
double kappa = .1;
153+
double t0 = .1;
154+
unsigned int init_buffer = 50;
155+
unsigned int term_buffer = 50;
156+
unsigned int window = 100;
157+
stan::test::unit::instrumented_interrupt interrupt;
158+
EXPECT_EQ(interrupt.call_count(), 0);
159+
160+
stan::services::sample::hmc_nuts_dense_e(
161+
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
162+
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
163+
max_depth, interrupt, logger, init, parameter, diagnostic);
164+
165+
for (auto&& init_it : init) {
166+
std::vector<std::string> init_values;
167+
init_values = init_it.string_values();
168+
169+
EXPECT_EQ(0, init_values.size());
170+
}
171+
172+
EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:"));
173+
EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)"));
174+
EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)"));
175+
EXPECT_EQ(num_chains, logger.find_info("seconds (Total)"));
176+
EXPECT_EQ(0, logger.call_count_error());
177+
}

src/test/unit/services/util/inv_metric_test.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ TEST(inv_metric, create_diag_sz100) {
2727
ASSERT_NEAR(1.0, diag_vals[99], 0.0001);
2828
}
2929

30+
TEST(inv_metric, create_dense_sz0) {
31+
auto default_metric = stan::services::util::create_unit_e_dense_inv_metric(0);
32+
stan::io::var_context& inv_inv_metric = default_metric;
33+
std::vector<double> diag_vals = inv_inv_metric.vals_r("inv_metric");
34+
EXPECT_EQ(0, diag_vals.size());
35+
}
36+
3037
TEST(inv_metric, create_dense_sz2) {
3138
auto default_metric = stan::services::util::create_unit_e_dense_inv_metric(2);
3239
stan::io::var_context& inv_inv_metric = default_metric;
@@ -122,9 +129,9 @@ TEST(inv_metric, read_dense_OK) {
122129

123130
TEST(inv_metric, read_dense_sz0) {
124131
stan::callbacks::logger logger;
125-
stan::io::dump dmp = stan::services::util::create_unit_e_dense_inv_metric(0);
132+
auto zero_metric = stan::services::util::create_unit_e_dense_inv_metric(0);
126133
Eigen::MatrixXd inv_inv_metric
127-
= stan::services::util::read_dense_inv_metric(dmp, 0, logger);
134+
= stan::services::util::read_dense_inv_metric(zero_metric, 0, logger);
128135
EXPECT_EQ(0, inv_inv_metric.size());
129136
EXPECT_EQ(0, inv_inv_metric.rows());
130137
EXPECT_EQ(0, inv_inv_metric.cols());

0 commit comments

Comments
 (0)