Skip to content

Commit 5912381

Browse files
authored
fix: load log_w minfloat correctly (#23)
1 parent 7f61f9f commit 5912381

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/graph/htps.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,15 @@ HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
809809
node.policy = j["policy"];
810810
node.exploration = j["exploration"];
811811
node.tactic_init_value = j["tactic_init_value"];
812-
node.log_w = static_cast<std::vector<double>>(j["log_w"]);
812+
std::vector<double> log_w;
813+
for (const auto &w: j["log_w"]) {
814+
if (w.is_null()) {
815+
log_w.push_back(MIN_FLOAT);
816+
} else {
817+
log_w.push_back(w);
818+
}
819+
}
820+
node.log_w = log_w;
813821
node.counts = static_cast<std::vector<size_t>>(j["counts"]);
814822
node.virtual_counts = static_cast<std::vector<size_t>>(j["virtual_counts"]);
815823
node.reset_mask = static_cast<std::vector<bool>>(j["reset_mask"]);

0 commit comments

Comments
 (0)