@@ -772,13 +772,11 @@ bool HTPSNode::has_virtual_count() const {
772772}
773773
774774HTPSNode HTPSNode::from_json (const nlohmann::json &j) {
775- HTPSNode node;
776- node.thm =j[" theorem" ];
775+ TheoremPointer thm = j[" theorem" ];
777776 std::vector<std::shared_ptr<tactic>> tactics;
778777 for (const auto &tac: j[" tactics" ]) {
779778 tactics.push_back (tac);
780779 }
781- node.tactics = tactics;
782780 std::vector<std::vector<TheoremPointer>> children_for_tactic;
783781 for (const auto &children: j[" children_for_tactic" ]) {
784782 std::vector<TheoremPointer> children_for_tactic_inner;
@@ -787,28 +785,30 @@ HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
787785 }
788786 children_for_tactic.push_back (children_for_tactic_inner);
789787 }
790- node.children_for_tactic = children_for_tactic;
791- node.killed_tactics = j[" killed_tactics" ].get <std::unordered_set<size_t >>();
792- node.solving_tactics = j[" solving_tactics" ].get <std::unordered_set<size_t >>();
793- node.tactic_expandable = j[" tactic_expandable" ].get <std::vector<bool >>();
794- node.minimum_proof_size = MinimumLengthMap::from_json (j[" minimum_proof_size" ]);
795- node.minimum_tactics = MinimumTacticMap::from_json (j[" minimum_tactics" ]);
796- node.minimum_tactic_length = MinimumTacticLengthMap::from_json (j[" minimum_tactic_length" ]);
797- node.in_minimum_proof = MinimumBoolMap::from_json (j[" in_minimum_proof" ]);
798- node.solved = j[" solved" ];
799- node.is_solved_leaf = j[" is_solved_leaf" ];
800- node.in_proof = j[" in_proof" ];
801- node.old_critic_value = j[" old_critic_value" ];
788+ children_for_tactic = children_for_tactic;
789+ std::unordered_set<size_t > killed_tactics = j[" killed_tactics" ].get <std::unordered_set<size_t >>();
790+ std::unordered_set<size_t > solving_tactics = j[" solving_tactics" ].get <std::unordered_set<size_t >>();
791+ std::vector<bool > tactic_expandable = j[" tactic_expandable" ].get <std::vector<bool >>();
792+ auto minimum_proof_size = MinimumLengthMap::from_json (j[" minimum_proof_size" ]);
793+ auto minimum_tactics = MinimumTacticMap::from_json (j[" minimum_tactics" ]);
794+ auto minimum_tactic_length = MinimumTacticLengthMap::from_json (j[" minimum_tactic_length" ]);
795+ auto in_minimum_proof = MinimumBoolMap::from_json (j[" in_minimum_proof" ]);
796+ bool solved = j[" solved" ];
797+ bool is_solved_leaf = j[" is_solved_leaf" ];
798+ bool in_proof = j[" in_proof" ];
799+ double old_critic_value = j[" old_critic_value" ];
800+ double log_critic_value;
802801 if (!j[" log_critic_value" ].is_null ()) {
803- node. log_critic_value = j[" log_critic_value" ];
802+ log_critic_value = j[" log_critic_value" ];
804803 } else {
805- node. log_critic_value = MIN_FLOAT;
804+ log_critic_value = MIN_FLOAT;
806805 }
807- node.priors = static_cast <std::vector<double >>(j[" priors" ]);
808- node.q_value_solved = j[" q_value_solved" ];
809- node.policy = j[" policy" ];
810- node.exploration = j[" exploration" ];
811- node.tactic_init_value = j[" tactic_init_value" ];
806+ auto priors = static_cast <std::vector<double >>(j[" priors" ]);
807+ QValueSolved q_value_solved = j[" q_value_solved" ];
808+ std::vector<std::shared_ptr<htps::env_effect>> effects = j[" effects" ];
809+ std::shared_ptr<Policy> policy = j[" policy" ];
810+ double exploration = j[" exploration" ];
811+ double tactic_init_value = j[" tactic_init_value" ];
812812 std::vector<double > log_w;
813813 for (const auto &w: j[" log_w" ]) {
814814 if (w.is_null ()) {
@@ -817,10 +817,25 @@ HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
817817 log_w.push_back (w);
818818 }
819819 }
820- node.log_w = log_w;
821- node.counts = static_cast <std::vector<size_t >>(j[" counts" ]);
822- node.virtual_counts = static_cast <std::vector<size_t >>(j[" virtual_counts" ]);
823- node.reset_mask = static_cast <std::vector<bool >>(j[" reset_mask" ]);
820+
821+ std::vector<size_t > counts = static_cast <std::vector<size_t >>(j[" counts" ]);
822+ std::vector<size_t > virtual_counts = static_cast <std::vector<size_t >>(j[" virtual_counts" ]);
823+ std::vector<bool > reset_mask = static_cast <std::vector<bool >>(j[" reset_mask" ]);
824+ bool error = j[" error" ];
825+
826+ HTPSNode node = {thm, tactics, children_for_tactic, policy, priors, exploration, log_critic_value, q_value_solved, tactic_init_value, effects, error};
827+ node.killed_tactics = killed_tactics;
828+ node.solving_tactics = solving_tactics;
829+ node.tactic_expandable = tactic_expandable;
830+ node.minimum_proof_size = minimum_proof_size;
831+ node.minimum_tactics = minimum_tactics;
832+ node.minimum_tactic_length = minimum_tactic_length;
833+ node.in_minimum_proof = in_minimum_proof;
834+ node.solved = solved;
835+ node.is_solved_leaf = is_solved_leaf;
836+ node.in_proof = in_proof;
837+ node.old_critic_value = old_critic_value;
838+ node.log_critic_value = log_critic_value;
824839 return node;
825840}
826841
@@ -850,6 +865,8 @@ HTPSNode::operator nlohmann::json() const {
850865 j[" counts" ] = counts;
851866 j[" virtual_counts" ] = virtual_counts;
852867 j[" reset_mask" ] = reset_mask;
868+ j[" error" ] = error;
869+ j[" effects" ] = effects;
853870 return j;
854871}
855872
0 commit comments