Skip to content

Commit 0b6d67b

Browse files
authored
Merge pull request #1191 from stan-dev/fix/init-multichain-bug
Fix: multichain does not use any init file besides the first
2 parents 5c0b7b2 + c23160f commit 0b6d67b

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

src/cmdstan/command.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ int command(int argc, const char *argv[]) {
210210
}
211211

212212
std::vector<std::shared_ptr<stan::io::var_context>> init_contexts
213-
= get_vec_var_context(init, num_chains);
213+
= get_vec_var_context(init, num_chains, id);
214214
std::vector<std::string> model_compile_info = model.model_compile_info();
215215

216216
for (int i = 0; i < num_chains; ++i) {
@@ -510,7 +510,7 @@ int command(int argc, const char *argv[]) {
510510
dynamic_cast<string_argument *>(algo->arg("hmc")->arg("metric_file"))
511511
->value());
512512
context_vector metric_contexts
513-
= get_vec_var_context(metric_filename, num_chains);
513+
= get_vec_var_context(metric_filename, num_chains, id);
514514
categorical_argument *adapt
515515
= dynamic_cast<categorical_argument *>(sample_arg->arg("adapt"));
516516
categorical_argument *hmc

src/cmdstan/command_helper.hpp

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,29 @@ inline shared_context_ptr get_var_context(const std::string &file) {
192192
return std::make_shared<stan::io::dump>(var_context);
193193
}
194194

195+
std::vector<std::string> make_filenames(const std::string &filename,
196+
const std::string &tag,
197+
const std::string &type,
198+
unsigned int num_chains,
199+
unsigned int id) {
200+
std::vector<std::string> names(num_chains);
201+
auto base_sfx = get_basename_suffix(filename);
202+
if (base_sfx.second.empty()) {
203+
base_sfx.second = type;
204+
}
205+
auto name_iterator = [num_chains, id](auto i) {
206+
if (num_chains == 1) {
207+
return std::string("");
208+
} else {
209+
return std::string("_" + std::to_string(i + id));
210+
}
211+
};
212+
for (int i = 0; i < num_chains; ++i) {
213+
names[i] = base_sfx.first + tag + name_iterator(i) + base_sfx.second;
214+
}
215+
return names;
216+
}
217+
195218
using context_vector = std::vector<shared_context_ptr>;
196219
/**
197220
* Make a vector of shared pointers to contexts.
@@ -201,7 +224,8 @@ using context_vector = std::vector<shared_context_ptr>;
201224
* @param num_chains The number of chains to run
202225
* @return a std vector of shared pointers to var contexts
203226
*/
204-
context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
227+
context_vector get_vec_var_context(const std::string &file, size_t num_chains,
228+
unsigned int id) {
205229
using stan::io::var_context;
206230
if (num_chains == 1) {
207231
return context_vector(1, get_var_context(file));
@@ -249,8 +273,9 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
249273
"\tConsider saving your data in JSON format instead."
250274
<< std::endl;
251275
}
252-
std::string file_1
253-
= std::string(file_name + "_" + std::to_string(1) + file_ending);
276+
277+
auto filenames = make_filenames(file_name, "", file_ending, num_chains, id);
278+
auto &file_1 = filenames[0];
254279
std::fstream stream_1(file_1.c_str(), std::fstream::in);
255280
// if file_1 exists we'll assume num_chains of these files exist
256281
if (stream_1.rdstate() & std::ifstream::failbit) {
@@ -274,9 +299,8 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains) {
274299
ret.reserve(num_chains);
275300
ret.push_back(make_context(file_1, stream_1, file_ending));
276301
for (size_t i = 1; i < num_chains; ++i) {
277-
std::string file_i
278-
= std::string(file_name + "_" + std::to_string(i) + file_ending);
279-
std::fstream stream_i(file_1.c_str(), std::fstream::in);
302+
auto &file_i = filenames[i];
303+
std::fstream stream_i(file_i.c_str(), std::fstream::in);
280304
// If any stream fails here something went wrong with file names
281305
if (stream_i.rdstate() & std::ifstream::failbit) {
282306
std::string file_name_err = std::string(
@@ -737,29 +761,6 @@ void check_file_config(argument_parser &parser) {
737761
}
738762
}
739763

740-
std::vector<std::string> make_filenames(const std::string &filename,
741-
const std::string &tag,
742-
const std::string &type,
743-
unsigned int num_chains,
744-
unsigned int id) {
745-
std::vector<std::string> names(num_chains);
746-
auto base_sfx = get_basename_suffix(filename);
747-
if (base_sfx.second.empty()) {
748-
base_sfx.second = type;
749-
}
750-
auto name_iterator = [num_chains, id](auto i) {
751-
if (num_chains == 1) {
752-
return std::string("");
753-
} else {
754-
return std::string("_" + std::to_string(i + id));
755-
}
756-
};
757-
for (int i = 0; i < num_chains; ++i) {
758-
names[i] = base_sfx.first + tag + name_iterator(i) + base_sfx.second;
759-
}
760-
return names;
761-
}
762-
763764
void init_callbacks(
764765
argument_parser &parser,
765766
std::vector<stan::callbacks::unique_stream_writer<std::ofstream>>

src/test/interface/multi_chain_init_test.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class CmdStan : public testing::Test {
1818
init_data = {"src", "test", "test-models", "bern_init.json"};
1919
init2_data = {"src", "test", "test-models", "bern_init2.json"};
2020
init3_data = {"src", "test", "test-models", "bern_init2.R"};
21+
init_bad_data = {"src", "test", "test-models", "bern_init_bad.json"};
2122
dev_null_path = {"/dev", "null"};
2223
}
2324
std::vector<std::string> bern_model;
@@ -26,6 +27,7 @@ class CmdStan : public testing::Test {
2627
std::vector<std::string> init_data;
2728
std::vector<std::string> init2_data;
2829
std::vector<std::string> init3_data;
30+
std::vector<std::string> init_bad_data;
2931
};
3032

3133
TEST_F(CmdStan, multi_chain_single_init_file_good) {
@@ -52,6 +54,44 @@ TEST_F(CmdStan, multi_chain_multi_init_file_good) {
5254
ASSERT_FALSE(out.hasError);
5355
}
5456

57+
TEST_F(CmdStan, multi_chain_multi_init_file_id_good) {
58+
std::stringstream ss;
59+
ss << convert_model_path(bern_model)
60+
<< " data file=" << convert_model_path(bern_data)
61+
<< " output file=" << convert_model_path(dev_null_path)
62+
<< " init=" << convert_model_path(init2_data) << " id=2"
63+
<< " method=sample num_chains=2";
64+
std::string cmd = ss.str();
65+
run_command_output out = run_command(cmd);
66+
ASSERT_FALSE(out.hasError) << out.output;
67+
}
68+
69+
TEST_F(CmdStan, multi_chain_multi_init_file_id_bad) {
70+
// this will start by requesting ..._4.json, which doesn't exist
71+
std::stringstream ss;
72+
ss << convert_model_path(bern_model)
73+
<< " data file=" << convert_model_path(bern_data)
74+
<< " output file=" << convert_model_path(dev_null_path)
75+
<< " init=" << convert_model_path(init2_data) << " id=4"
76+
<< " method=sample num_chains=3";
77+
std::string cmd = ss.str();
78+
run_command_output out = run_command(cmd);
79+
ASSERT_TRUE(out.hasError);
80+
}
81+
82+
TEST_F(CmdStan, multi_chain_multi_init_file_actually_used) {
83+
// the second chain has a bad init value
84+
std::stringstream ss;
85+
ss << convert_model_path(bern_model)
86+
<< " data file=" << convert_model_path(bern_data)
87+
<< " output file=" << convert_model_path(dev_null_path)
88+
<< " init=" << convert_model_path(init_bad_data)
89+
<< " method=sample num_chains=2";
90+
std::string cmd = ss.str();
91+
run_command_output out = run_command(cmd);
92+
ASSERT_TRUE(out.hasError) << out.output;
93+
}
94+
5595
TEST_F(CmdStan, multi_chain_multi_init_file_R) {
5696
std::stringstream ss;
5797
ss << convert_model_path(bern_model)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"theta" : 0.1
3+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"theta" : 3.0
3+
}

0 commit comments

Comments
 (0)