@@ -45,47 +45,42 @@ static PyObject *make_enum(PyObject *module, PyObject *enum_module, const char *
4545static PyObject *make_policy_type (PyObject *module , PyObject *enum_module) {
4646 if (PolicyTypeEnum != NULL )
4747 return PolicyTypeEnum;
48- size_t value_size = 2 ;
49- const char *values[value_size] = {" AlphaZero" , " RPO" };
50- PolicyTypeEnum = make_enum (module , enum_module, values, value_size, " PolicyType" );
48+ const char *values[] = {" AlphaZero" , " RPO" };
49+ PolicyTypeEnum = make_enum (module , enum_module, values, 2 , " PolicyType" );
5150 return PolicyTypeEnum;
5251}
5352
5453static PyObject *make_q_value_solved (PyObject *module , PyObject *enum_module) {
5554 if (QValueSolvedEnum != NULL )
5655 return QValueSolvedEnum;
57- size_t value_size = 5 ;
58- const char *values[value_size] = {
56+ const char *values[] = {
5957 " OneOverCounts" , " CountOverCounts" , " One" , " OneOverVirtualCounts" , " OneOverCountsNoFPU" , " CountOverCountsNoFPU"
6058 };
61- QValueSolvedEnum = make_enum (module , enum_module, values, value_size , " QValueSolved" );
59+ QValueSolvedEnum = make_enum (module , enum_module, values, 6 , " QValueSolved" );
6260 return QValueSolvedEnum;
6361}
6462
6563static PyObject *make_node_mask (PyObject *module , PyObject *enum_module) {
6664 if (NodeMaskEnum != NULL )
6765 return NodeMaskEnum;
68- size_t value_size = 5 ;
69- const char *values[value_size] = {" NoMask" , " Solving" , " Proof" , " MinimalProof" , " MinimalProofSolving" };
70- NodeMaskEnum = make_enum (module , enum_module, values, value_size, " NodeMask" );
66+ const char *values[5 ] = {" NoMask" , " Solving" , " Proof" , " MinimalProof" , " MinimalProofSolving" };
67+ NodeMaskEnum = make_enum (module , enum_module, values, 5 , " NodeMask" );
7168 return NodeMaskEnum;
7269}
7370
7471static PyObject *make_metric (PyObject *module , PyObject *enum_module) {
7572 if (MetricEnum != NULL )
7673 return MetricEnum;
77- size_t value_size = 3 ;
78- const char *values[value_size] = {" Depth" , " Size" , " Time" };
79- MetricEnum = make_enum (module , enum_module, values, value_size, " Metric" );
74+ const char *values[3 ] = {" Depth" , " Size" , " Time" };
75+ MetricEnum = make_enum (module , enum_module, values, 3 , " Metric" );
8076 return MetricEnum;
8177}
8278
8379static PyObject *make_in_proof (PyObject *module , PyObject *enum_module) {
8480 if (InProofEnum != NULL )
8581 return InProofEnum;
86- size_t value_size = 3 ;
87- const char *values[value_size] = {" NotInProof" , " InProof" , " InMinimalProof" };
88- InProofEnum = make_enum (module , enum_module, values, value_size, " InProof" );
82+ const char *values[3 ] = {" NotInProof" , " InProof" , " InMinimalProof" };
83+ InProofEnum = make_enum (module , enum_module, values, 3 , " InProof" );
8984 return InProofEnum;
9085}
9186
@@ -1847,7 +1842,12 @@ static int EnvExpansion_init(PyObject *self, PyObject *args, PyObject *kwargs) {
18471842 PyErr_SetString (PyExc_ValueError, " priors must sum to 1" );
18481843 return -1 ;
18491844 }
1850-
1845+ std::vector<size_t > sizes = {env_durations.size (), effects.size (), tactics.size (), children_for_tactic.size ()};
1846+ bool are_same = std::all_of (sizes.begin (), sizes.end (), [priors](size_t value) {return priors.size () == value;});
1847+ if (!are_same) {
1848+ PyErr_SetString (PyExc_ValueError, " Priors, tactics, Durations, Effects and Children for Tactic must be of the same size!" );
1849+ return -1 ;
1850+ }
18511851
18521852 new (&(((PyEnvExpansion *)self)->expansion )) htps::env_expansion (
18531853 shared_thm, expander_duration, generation_duration, env_durations,
@@ -1912,6 +1912,23 @@ static int EnvExpansion_init(PyObject *self, PyObject *args, PyObject *kwargs) {
19121912
19131913}
19141914
1915+ static PyObject* EnvExpansion_get_jsonstr (PyObject *self, PyObject *args) {
1916+ auto *obj = (PyEnvExpansion *)self;
1917+ nlohmann::json j = obj->expansion .operator nlohmann::json ();
1918+ return PyObject_from_string (j.dump ());
1919+ }
1920+
1921+ static PyObject* EnvExpansion_from_jsonstr (PyTypeObject *type, PyObject *args) {
1922+ const char *json_str;
1923+ if (!PyArg_ParseTuple (args, " s" , &json_str)) {
1924+ return NULL ;
1925+ }
1926+ nlohmann::json j = nlohmann::json::parse (json_str);
1927+ auto *self = (PyEnvExpansion *)EnvExpansion_new (type, NULL , NULL );
1928+ self->expansion = htps::env_expansion::from_json (j);
1929+ return (PyObject *)self;
1930+ }
1931+
19151932static PyGetSetDef EnvExpansion_getsetters[] = {
19161933 {" thm" , (getter)EnvExpansion_get_thm, (setter)EnvExpansion_set_thm, " Theorem for expansion" , NULL },
19171934 {" expander_duration" , (getter)EnvExpansion_get_expander_duration, (setter)EnvExpansion_set_expander_duration, " Expander duration" , NULL },
@@ -1927,6 +1944,12 @@ static PyGetSetDef EnvExpansion_getsetters[] = {
19271944 {NULL }
19281945};
19291946
1947+ static PyMethodDef EnvExpansion_methods[] = {
1948+ {" get_json_str" , (PyCFunction)EnvExpansion_get_jsonstr, METH_NOARGS, " Get JSON string representation" },
1949+ {" from_json_str" , (PyCFunction)EnvExpansion_from_jsonstr, METH_VARARGS | METH_CLASS, " Create from JSON string" },
1950+ {NULL , NULL , 0 , NULL }
1951+ };
1952+
19301953
19311954static PyTypeObject EnvExpansionType = {
19321955 PyObject_HEAD_INIT (NULL ) " htps.EnvExpansion" ,
@@ -1955,7 +1978,7 @@ NULL,
19551978NULL ,
19561979NULL ,
19571980NULL ,
1958- NULL ,
1981+ EnvExpansion_methods ,
19591982NULL ,
19601983EnvExpansion_getsetters,
19611984NULL ,
@@ -3166,13 +3189,46 @@ static PyObject* PyHTPS_get_result(PyHTPS *self, PyObject *Py_UNUSED(ignored)) {
31663189 return PyHTPSResult_NewFromResult (result);
31673190}
31683191
3192+ static PyObject* PyHTPS_get_jsonstr (PyHTPS *self, PyObject *Py_UNUSED (ignored)) {
3193+ std::string result;
3194+ try {
3195+ result = nlohmann::json (self->graph ).dump ();
3196+ } catch (std::exception &e) {
3197+ PyErr_SetString (PyExc_RuntimeError, e.what ());
3198+ return NULL ;
3199+ }
3200+ return PyObject_from_string (result);
3201+ }
3202+
3203+ static PyObject* PyHTPS_from_jsonstr (PyTypeObject *type, PyObject *args) {
3204+ const char *json_str;
3205+ if (!PyArg_ParseTuple (args, " s" , &json_str)) {
3206+ PyErr_SetString (PyExc_TypeError, " from_jsonstr expects a string" );
3207+ return NULL ;
3208+ }
3209+ try {
3210+ auto json = nlohmann::json::parse (json_str);
3211+ auto graph = htps::HTPS::from_json (json);
3212+ PyObject *obj = HTPS_new (type, NULL , NULL );
3213+ if (obj == NULL )
3214+ return NULL ;
3215+ auto *py_graph = (PyHTPS *) obj;
3216+ py_graph->graph = graph;
3217+ return obj;
3218+ } catch (std::exception &e) {
3219+ PyErr_SetString (PyExc_RuntimeError, e.what ());
3220+ return NULL ;
3221+ }
3222+ }
31693223
31703224static PyMethodDef HTPS_methods[] = {
31713225 {" theorems_to_expand" , (PyCFunction)PyHTPS_theorems_to_expand, METH_NOARGS, " Returns a list of subsequent theorems to expand" },
31723226 {" expand_and_backup" , (PyCFunction)PyHTPS_expand_and_backup, METH_VARARGS, " Expands and backups using the provided list of EnvExpansion objects" },
31733227 {" proven" , (PyCFunction)PyHTPS_is_proven, METH_NOARGS, " Whether the start theorem is proven or not" },
31743228 {" get_result" , (PyCFunction)PyHTPS_get_result, METH_NOARGS, " Returns the result of the HTPS run" },
31753229 {" is_done" , (PyCFunction)PyHTPS_is_done, METH_NOARGS, " Whether the HTPS run is done or not" },
3230+ {" get_json_str" , (PyCFunction)PyHTPS_get_jsonstr, METH_NOARGS, " Returns a JSON string representation of the HTPS object" },
3231+ {" from_json_str" , (PyCFunction)PyHTPS_from_jsonstr, METH_VARARGS | METH_CLASS, " Creates a HTPS object from a JSON string" },
31763232 {NULL , NULL , 0 , NULL }
31773233};
31783234
0 commit comments