Skip to content

Commit 44c1d75

Browse files
authored
fix: update HTPSNode json parsing to include effects and error (#26)
* fix: update HTPSNode json parsing to include effects and error, remove default constructor * fix: update samples json to include effects and error
1 parent c4601a1 commit 44c1d75

File tree

5 files changed

+71076
-30
lines changed

5 files changed

+71076
-30
lines changed

samples/expansions_6.json

Lines changed: 14960 additions & 1 deletion
Large diffs are not rendered by default.

samples/search_6.json

Lines changed: 55959 additions & 1 deletion
Large diffs are not rendered by default.

samples/test.json

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,63 @@
7373
0,
7474
0
7575
],
76+
"effects": [
77+
{
78+
"children": [
79+
{
80+
"conclusion": "C",
81+
"ctx": {
82+
"namespaces": []
83+
},
84+
"hypotheses": [],
85+
"past_tactics": [],
86+
"unique_string": "C"
87+
}
88+
],
89+
"goal": {
90+
"conclusion": "B",
91+
"ctx": {
92+
"namespaces": []
93+
},
94+
"hypotheses": [],
95+
"past_tactics": [],
96+
"unique_string": "B"
97+
},
98+
"tac": {
99+
"duration": 0,
100+
"is_valid": true,
101+
"unique_string": "TACC"
102+
}
103+
},
104+
{
105+
"children": [
106+
{
107+
"conclusion": "C",
108+
"ctx": {
109+
"namespaces": []
110+
},
111+
"hypotheses": [],
112+
"past_tactics": [],
113+
"unique_string": "C"
114+
}
115+
],
116+
"goal": {
117+
"conclusion": "B",
118+
"ctx": {
119+
"namespaces": []
120+
},
121+
"hypotheses": [],
122+
"past_tactics": [],
123+
"unique_string": "B"
124+
},
125+
"tac": {
126+
"duration": 0,
127+
"is_valid": true,
128+
"unique_string": "TACD"
129+
}
130+
}
131+
],
132+
"error": false,
76133
"exploration": 0.3,
77134
"in_minimum_proof": {
78135
"DEPTH": false,
@@ -178,6 +235,63 @@
178235
2,
179236
1
180237
],
238+
"effects": [
239+
{
240+
"children": [
241+
{
242+
"conclusion": "B",
243+
"ctx": {
244+
"namespaces": []
245+
},
246+
"hypotheses": [],
247+
"past_tactics": [],
248+
"unique_string": "B"
249+
}
250+
],
251+
"goal": {
252+
"conclusion": "A",
253+
"ctx": {
254+
"namespaces": []
255+
},
256+
"hypotheses": [],
257+
"past_tactics": [],
258+
"unique_string": "A"
259+
},
260+
"tac": {
261+
"duration": 0,
262+
"is_valid": true,
263+
"unique_string": "TACA"
264+
}
265+
},
266+
{
267+
"children": [
268+
{
269+
"conclusion": "B",
270+
"ctx": {
271+
"namespaces": []
272+
},
273+
"hypotheses": [],
274+
"past_tactics": [],
275+
"unique_string": "B"
276+
}
277+
],
278+
"goal": {
279+
"conclusion": "A",
280+
"ctx": {
281+
"namespaces": []
282+
},
283+
"hypotheses": [],
284+
"past_tactics": [],
285+
"unique_string": "A"
286+
},
287+
"tac": {
288+
"duration": 0,
289+
"is_valid": true,
290+
"unique_string": "TACB"
291+
}
292+
}
293+
],
294+
"error": false,
181295
"exploration": 0.3,
182296
"in_minimum_proof": {
183297
"DEPTH": false,

src/graph/htps.cpp

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -772,13 +772,11 @@ bool HTPSNode::has_virtual_count() const {
772772
}
773773

774774
HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
775-
HTPSNode node;
776-
node.thm =j["theorem"];
775+
TheoremPointer thm = j["theorem"];
777776
std::vector<std::shared_ptr<tactic>> tactics;
778777
for (const auto &tac: j["tactics"]) {
779778
tactics.push_back(tac);
780779
}
781-
node.tactics = tactics;
782780
std::vector<std::vector<TheoremPointer>> children_for_tactic;
783781
for (const auto &children: j["children_for_tactic"]) {
784782
std::vector<TheoremPointer> children_for_tactic_inner;
@@ -787,28 +785,30 @@ HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
787785
}
788786
children_for_tactic.push_back(children_for_tactic_inner);
789787
}
790-
node.children_for_tactic = children_for_tactic;
791-
node.killed_tactics = j["killed_tactics"].get<std::unordered_set<size_t>>();
792-
node.solving_tactics = j["solving_tactics"].get<std::unordered_set<size_t>>();
793-
node.tactic_expandable = j["tactic_expandable"].get<std::vector<bool>>();
794-
node.minimum_proof_size = MinimumLengthMap::from_json(j["minimum_proof_size"]);
795-
node.minimum_tactics = MinimumTacticMap::from_json(j["minimum_tactics"]);
796-
node.minimum_tactic_length = MinimumTacticLengthMap::from_json(j["minimum_tactic_length"]);
797-
node.in_minimum_proof = MinimumBoolMap::from_json(j["in_minimum_proof"]);
798-
node.solved = j["solved"];
799-
node.is_solved_leaf = j["is_solved_leaf"];
800-
node.in_proof = j["in_proof"];
801-
node.old_critic_value = j["old_critic_value"];
788+
children_for_tactic = children_for_tactic;
789+
std::unordered_set<size_t> killed_tactics = j["killed_tactics"].get<std::unordered_set<size_t>>();
790+
std::unordered_set<size_t> solving_tactics = j["solving_tactics"].get<std::unordered_set<size_t>>();
791+
std::vector<bool> tactic_expandable = j["tactic_expandable"].get<std::vector<bool>>();
792+
auto minimum_proof_size = MinimumLengthMap::from_json(j["minimum_proof_size"]);
793+
auto minimum_tactics = MinimumTacticMap::from_json(j["minimum_tactics"]);
794+
auto minimum_tactic_length = MinimumTacticLengthMap::from_json(j["minimum_tactic_length"]);
795+
auto in_minimum_proof = MinimumBoolMap::from_json(j["in_minimum_proof"]);
796+
bool solved = j["solved"];
797+
bool is_solved_leaf = j["is_solved_leaf"];
798+
bool in_proof = j["in_proof"];
799+
double old_critic_value = j["old_critic_value"];
800+
double log_critic_value;
802801
if (!j["log_critic_value"].is_null()) {
803-
node.log_critic_value = j["log_critic_value"];
802+
log_critic_value = j["log_critic_value"];
804803
} else {
805-
node.log_critic_value = MIN_FLOAT;
804+
log_critic_value = MIN_FLOAT;
806805
}
807-
node.priors = static_cast<std::vector<double>>(j["priors"]);
808-
node.q_value_solved = j["q_value_solved"];
809-
node.policy = j["policy"];
810-
node.exploration = j["exploration"];
811-
node.tactic_init_value = j["tactic_init_value"];
806+
auto priors = static_cast<std::vector<double>>(j["priors"]);
807+
QValueSolved q_value_solved = j["q_value_solved"];
808+
std::vector<std::shared_ptr<htps::env_effect>> effects = j["effects"];
809+
std::shared_ptr<Policy> policy = j["policy"];
810+
double exploration = j["exploration"];
811+
double tactic_init_value = j["tactic_init_value"];
812812
std::vector<double> log_w;
813813
for (const auto &w: j["log_w"]) {
814814
if (w.is_null()) {
@@ -817,10 +817,25 @@ HTPSNode HTPSNode::from_json(const nlohmann::json &j) {
817817
log_w.push_back(w);
818818
}
819819
}
820-
node.log_w = log_w;
821-
node.counts = static_cast<std::vector<size_t>>(j["counts"]);
822-
node.virtual_counts = static_cast<std::vector<size_t>>(j["virtual_counts"]);
823-
node.reset_mask = static_cast<std::vector<bool>>(j["reset_mask"]);
820+
821+
std::vector<size_t> counts = static_cast<std::vector<size_t>>(j["counts"]);
822+
std::vector<size_t> virtual_counts = static_cast<std::vector<size_t>>(j["virtual_counts"]);
823+
std::vector<bool> reset_mask = static_cast<std::vector<bool>>(j["reset_mask"]);
824+
bool error = j["error"];
825+
826+
HTPSNode node = {thm, tactics, children_for_tactic, policy, priors, exploration, log_critic_value, q_value_solved, tactic_init_value, effects, error};
827+
node.killed_tactics = killed_tactics;
828+
node.solving_tactics = solving_tactics;
829+
node.tactic_expandable = tactic_expandable;
830+
node.minimum_proof_size = minimum_proof_size;
831+
node.minimum_tactics = minimum_tactics;
832+
node.minimum_tactic_length = minimum_tactic_length;
833+
node.in_minimum_proof = in_minimum_proof;
834+
node.solved = solved;
835+
node.is_solved_leaf = is_solved_leaf;
836+
node.in_proof = in_proof;
837+
node.old_critic_value = old_critic_value;
838+
node.log_critic_value = log_critic_value;
824839
return node;
825840
}
826841

@@ -850,6 +865,8 @@ HTPSNode::operator nlohmann::json() const {
850865
j["counts"] = counts;
851866
j["virtual_counts"] = virtual_counts;
852867
j["reset_mask"] = reset_mask;
868+
j["error"] = error;
869+
j["effects"] = effects;
853870
return j;
854871
}
855872

src/graph/htps.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,6 @@ namespace htps {
431431
reset_HTPS_stats();
432432
}
433433

434-
HTPSNode() = default;
435-
436434
/* Reset the HTPS statistics, resetting counts and logW values.
437435
* */
438436
void reset_HTPS_stats();

0 commit comments

Comments
 (0)