Skip to content

Commit 0001f54

Browse files
authored
Add clean parameter to json_read (#202)
1 parent 07fb6f0 commit 0001f54

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

test/python/test_xcsf.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
import json
2626
import os
2727
import pickle
28+
import numbers
2829
import numpy as np
2930
import pytest
31+
from copy import deepcopy
3032
from sklearn.model_selection import train_test_split
3133
from sklearn.preprocessing import MinMaxScaler
3234
from sklearn.datasets import make_regression
@@ -289,3 +291,93 @@ def test_seeding(data):
289291
# clean up
290292
if os.path.exists(POP_FILENAME):
291293
os.remove(POP_FILENAME)
294+
295+
296+
def _compare_dicts(d1, d2, path=""):
297+
diffs = []
298+
all_keys = set(d1.keys()) | set(d2.keys())
299+
300+
for key in all_keys:
301+
subpath = f"{path}.{key}" if path else key
302+
303+
if key not in d1 or key not in d2:
304+
diffs.append((subpath, "Path exists in only one dict"))
305+
continue
306+
307+
v1, v2 = d1[key], d2[key]
308+
309+
if isinstance(v1, dict) and isinstance(v2, dict):
310+
diffs.extend(_compare_dicts(v1, v2, subpath))
311+
elif isinstance(v1, list) and isinstance(v2, list):
312+
if len(v1) != len(v2):
313+
diffs.append((subpath, f"List length differs: {len(v1)} != {len(v2)}"))
314+
for i, (x, y) in enumerate(zip(v1, v2)):
315+
diffs.extend(_compare_dicts({0: x}, {0: y}, f"{subpath}[{i}]"))
316+
elif isinstance(v1, numbers.Real) and isinstance(v2, numbers.Real):
317+
if not np.isclose(v1, v2, atol=1e-10, rtol=0.0):
318+
diffs.append((subpath, f"{v1} != {v2}"))
319+
elif v1 != v2:
320+
diffs.append((subpath, f"{v1} != {v2}"))
321+
322+
return diffs
323+
324+
325+
def _test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
    """Check that ``json_read`` reproduces a pruned population stored on disk.

    An XCS instance is fitted on random data, its population is exported,
    the first classifier is removed and the remainder is written to
    ``tmp_path / "pset1.json"``. The file is then read back with
    ``json_read`` and the resulting classifiers are compared with the ones
    that were written.

    Args:
        tmp_path: Directory (``pathlib.Path``-like) for the temporary file.
        pop_init: Forwarded to ``xcsf.XCS(pop_init=...)``.
        clean: Forwarded to ``json_read(clean=...)``.
        fitinbetween: Whether to call ``fit`` again before reading the file.
        warm_start: Forwarded to the intermediate ``fit`` call (only used
            when ``fitinbetween`` is true).

    Returns:
        True if the classifiers after ``json_read`` match the pruned list
        (same length, no difference reported by ``_compare_dicts``).
    """
    N = 500
    DX = 3
    X = np.random.random((N, DX))
    y = np.random.randn(N, 1)

    xcs = xcsf.XCS(x_dim=DX, pop_size=5, max_trials=1000, pop_init=pop_init)
    xcs.fit(X, y, verbose=False)

    # Initial, "too large" population.
    pop0 = json.loads(xcs.json())

    # "Pruning": drop the first classifier and store the remainder.
    pop1 = deepcopy(pop0)
    del pop1["classifiers"][0]
    pset_file = tmp_path / "pset1.json"
    pset_file.write_text(json.dumps(pop1))

    if fitinbetween:
        # Forward the parameter: it was previously hard-coded to True, which
        # made the warm_start=False test cases duplicates of the True ones.
        xcs.fit(X, y, warm_start=warm_start, verbose=False)

    xcs.json_read(str(pset_file), clean=clean)

    expected = pop1["classifiers"]
    # Pipe through `loads` because the expected side was parsed as well.
    actual = json.loads(xcs.json())["classifiers"]

    if len(expected) != len(actual):
        return False
    # Any reported difference for any classifier pair means failure.
    return not any(_compare_dicts(cl1, cl2) for cl1, cl2 in zip(expected, actual))
365+
366+
367+
@pytest.mark.parametrize(
    "pop_init,clean,fitinbetween,warm_start",
    [
        # clean is always True; for each pop_init value run: no refit,
        # refit with warm_start=False, refit with warm_start=True.
        (init, True, refit, warm)
        for init in (False, True)
        for refit, warm in ((False, False), (True, False), (True, True))
    ],
)
def test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
    """Run the population-replacement check for seeds 0..18."""
    for seed in range(19):
        np.random.seed(seed)
        replaced_ok = _test_pop_replace(
            tmp_path, pop_init, clean, fitinbetween, warm_start
        )
        assert replaced_ok, f"failed at seed {seed}"

xcsf/pybind_wrapper.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -943,8 +943,12 @@ class XCS
943943
* @param [in] filename Name of the input file.
944944
*/
945945
void
946-
json_read(const std::string &filename)
946+
json_read(const std::string &filename, const bool clean)
947947
{
948+
if (clean) {
949+
clset_kill(&xcs, &xcs.pset);
950+
clset_init(&xcs.pset);
951+
}
948952
std::ifstream infile(filename);
949953
std::stringstream buffer;
950954
buffer << infile.rdbuf();
@@ -1105,7 +1109,7 @@ PYBIND11_MODULE(xcsf, m)
11051109
py::arg("filename"))
11061110
.def("json_read", &XCS::json_read,
11071111
"Reads classifiers from a JSON file and adds to the population.",
1108-
py::arg("filename"))
1112+
py::arg("filename"), py::arg("clean") = true)
11091113
.def("get_params", &XCS::get_params, py::arg("deep") = true,
11101114
"Returns a dictionary of parameters and their values.")
11111115
.def("set_params", &XCS::set_params, "Sets parameters.")

0 commit comments

Comments
 (0)