Skip to content

Commit ee76033

Browse files
authored
Merge pull request #1318 from stan-dev/config-file-comma-output
Some missed cases in comma separation
2 parents 7d8fcbb + 03439d6 commit ee76033

File tree

4 files changed

+105
-23
lines changed

4 files changed

+105
-23
lines changed

src/cmdstan/command.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,14 @@ int command(int argc, const char *argv[]) {
277277
= get_vec_var_context(init, num_chains, id);
278278

279279
if (get_arg_val<bool_argument>(parser, "output", "save_cmdstan_config")) {
280+
auto base_file = output_file;
281+
// when there are commas, take first file
282+
auto comma_pos = base_file.find(',');
283+
if (comma_pos != std::string::npos) {
284+
base_file = base_file.substr(0, comma_pos);
285+
}
280286
auto config_filename
281-
= file::get_basename_suffix(output_file).first + "_config.json";
287+
= file::get_basename_suffix(base_file).first + "_config.json";
282288
auto ofs_args = file::safe_create(config_filename, sig_figs);
283289
stan::callbacks::json_writer<std::ostream> json_args(std::move(ofs_args));
284290
write_config(json_args, parser, model);

src/cmdstan/file.hpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ std::pair<std::string, std::string> get_basename_suffix(
105105
return {base, suffix};
106106
}
107107

108+
std::vector<std::string> split_on_comma(const std::string &input) {
109+
std::vector<std::string> result;
110+
boost::algorithm::split(result, input, boost::is_any_of(","),
111+
boost::token_compress_on);
112+
return result;
113+
}
114+
108115
/**
109116
* Check if two file paths are the same file.
110117
* @note This function only handles very basic access patterns.
@@ -117,6 +124,27 @@ std::pair<std::string, std::string> get_basename_suffix(
117124
*/
118125
bool check_approx_same_file(const std::string &path1,
119126
const std::string &path2) {
127+
// if path1 has a comma in it, check all of them
128+
// if any match, return true
129+
if (path1.find(',') != std::string::npos) {
130+
for (const auto &name : split_on_comma(path1)) {
131+
if (check_approx_same_file(name, path2)) {
132+
return true;
133+
}
134+
}
135+
return false;
136+
}
137+
138+
// same for path2
139+
if (path2.find(',') != std::string::npos) {
140+
for (const auto &name : split_on_comma(path2)) {
141+
if (check_approx_same_file(path1, name)) {
142+
return true;
143+
}
144+
}
145+
return false;
146+
}
147+
120148
const auto path1_size = path1.size();
121149
const auto path2_size = path2.size();
122150
if (path1.empty() || path2.empty()) {
@@ -163,6 +191,14 @@ bool check_approx_same_file(const std::string &path1,
163191
* @param fname candidate output filename
164192
*/
165193
void validate_output_filename(const std::string &fname) {
194+
// if a , is present, check all values
195+
if (fname.find(',') != std::string::npos) {
196+
for (const auto &name : split_on_comma(fname)) {
197+
validate_output_filename(name);
198+
}
199+
return;
200+
}
201+
166202
std::string sep = std::string(1, cmdstan::file::PATH_SEPARATOR);
167203
if (!fname.empty()
168204
&& (fname[fname.size() - 1] == PATH_SEPARATOR
@@ -195,9 +231,7 @@ std::vector<std::string> make_filenames(const std::string &filename,
195231

196232
// if a ',' is present, we assume the user fully specified the names
197233
if (filename.find(',') != std::string::npos) {
198-
std::vector<std::string> filenames;
199-
boost::algorithm::split(filenames, filename, boost::is_any_of(","),
200-
boost::token_compress_on);
234+
std::vector<std::string> filenames = split_on_comma(filename);
201235
if (filenames.size() != num_chains) {
202236
std::stringstream msg;
203237
msg << "Number of filenames does not match number of chains: got "

src/test/interface/config_json_test.cpp

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,67 @@ using cmdstan::test::run_command_output;
1010
class CmdStan : public testing::Test {
1111
public:
1212
void SetUp() {
13-
multi_normal_model = {"src", "test", "test-models", "multi_normal_model"};
14-
arg_output = {"test", "output"};
15-
output_csv = {"test", "output.csv"};
16-
output_json = {"test", "output_config.json"};
13+
multi_normal_model = convert_model_path(
14+
std::vector{"src", "test", "test-models", "multi_normal_model"});
15+
arg_output = convert_model_path(std::vector{"test", "output"});
16+
17+
output_csv = convert_model_path(std::vector{"test", "output.csv"});
18+
output_json = convert_model_path(std::vector{"test", "output_config.json"});
19+
20+
output_csv_multi
21+
= convert_model_path(std::vector{"test", "output_multi.csv"});
22+
output_json_multi
23+
= convert_model_path(std::vector{"test", "output_multi_config.json"});
1724
}
1825

1926
void TearDown() {
20-
std::remove(convert_model_path(output_csv).c_str());
21-
std::remove(convert_model_path(output_json).c_str());
27+
std::remove(output_csv.c_str());
28+
std::remove(output_json.c_str());
29+
std::remove(output_csv_multi.c_str());
30+
std::remove(output_json_multi.c_str());
2231
}
2332

24-
std::vector<std::string> multi_normal_model;
25-
std::vector<std::string> arg_output;
26-
std::vector<std::string> output_csv;
27-
std::vector<std::string> output_json;
33+
std::string multi_normal_model;
34+
std::string arg_output;
35+
std::string output_csv;
36+
std::string output_json;
37+
38+
std::string output_csv_multi;
39+
std::string output_json_multi;
2840
};
2941

3042
TEST_F(CmdStan, config_json_output_valid) {
3143
std::stringstream ss;
32-
ss << convert_model_path(multi_normal_model)
33-
<< " sample output file=" << convert_model_path(arg_output)
44+
ss << multi_normal_model << " sample output file=" << arg_output
3445
<< " save_cmdstan_config=1";
3546
run_command_output out = run_command(ss.str());
3647
ASSERT_FALSE(out.hasError) << out.output;
37-
ASSERT_TRUE(file_exists(convert_model_path(output_csv)));
38-
ASSERT_TRUE(file_exists(convert_model_path(output_json)));
48+
ASSERT_TRUE(file_exists(output_csv));
49+
ASSERT_TRUE(file_exists(output_json));
50+
51+
std::fstream json_in(output_json);
52+
std::stringstream result_json_sstream;
53+
result_json_sstream << json_in.rdbuf();
54+
json_in.close();
55+
std::string json = result_json_sstream.str();
56+
57+
ASSERT_FALSE(json.empty());
58+
ASSERT_TRUE(is_valid_JSON(json));
59+
}
60+
61+
TEST_F(CmdStan, config_json_output_valid_multi) {
62+
std::stringstream ss;
63+
ss << multi_normal_model
64+
<< " sample num_chains=2 output file=" << output_csv_multi << ","
65+
<< output_csv << " save_cmdstan_config=true";
66+
run_command_output out = run_command(ss.str());
67+
ASSERT_FALSE(out.hasError) << out.output;
68+
ASSERT_TRUE(file_exists(output_csv_multi));
69+
ASSERT_TRUE(file_exists(output_csv));
70+
ASSERT_TRUE(file_exists(output_json_multi));
71+
ASSERT_FALSE(file_exists(output_json));
3972

40-
std::fstream json_in(convert_model_path(output_json));
73+
std::fstream json_in(output_json_multi);
4174
std::stringstream result_json_sstream;
4275
result_json_sstream << json_in.rdbuf();
4376
json_in.close();
@@ -49,10 +82,9 @@ TEST_F(CmdStan, config_json_output_valid) {
4982

5083
TEST_F(CmdStan, config_json_output_not_requested) {
5184
std::stringstream ss;
52-
ss << convert_model_path(multi_normal_model)
53-
<< " sample output file=" << convert_model_path(arg_output);
85+
ss << multi_normal_model << " sample output file=" << arg_output;
5486
run_command_output out = run_command(ss.str());
5587
ASSERT_FALSE(out.hasError);
56-
ASSERT_TRUE(file_exists(convert_model_path(output_csv)));
57-
ASSERT_FALSE(file_exists(convert_model_path(output_json)));
88+
ASSERT_TRUE(file_exists(output_csv));
89+
ASSERT_FALSE(file_exists(output_json));
5890
}

src/test/interface/file_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ TEST(CommandHelper, validate_output_filename) {
7575

7676
std::string fp4 = "foo.bar" + sep + ".";
7777
EXPECT_THROW(validate_output_filename(fp4), std::invalid_argument);
78+
79+
std::string fp5 = "foo.bar,baz.gak";
80+
EXPECT_NO_THROW(validate_output_filename(fp5));
81+
82+
std::string fp6 = "foo.bar,foo.bar" + sep + ".";
83+
EXPECT_THROW(validate_output_filename(fp6), std::invalid_argument);
7884
}
7985

8086
TEST(CommandHelper, make_filenames) {
@@ -253,4 +259,8 @@ TEST(CommandHelper, check_same_file_test) {
253259
EXPECT_FALSE(cmdstan::file::check_approx_same_file(path, dot_path_bad));
254260
std::string dot_path_good = "a/b/file.txt";
255261
EXPECT_TRUE(cmdstan::file::check_approx_same_file(path, dot_path_good));
262+
263+
std::string comma_path = path + "," + path;
264+
EXPECT_TRUE(cmdstan::file::check_approx_same_file(path, comma_path));
265+
EXPECT_TRUE(cmdstan::file::check_approx_same_file(comma_path, path));
256266
}

0 commit comments

Comments
 (0)