|
25 | 25 | import json |
26 | 26 | import os |
27 | 27 | import pickle |
| 28 | +import numbers |
28 | 29 | import numpy as np |
29 | 30 | import pytest |
| 31 | +from copy import deepcopy |
30 | 32 | from sklearn.model_selection import train_test_split |
31 | 33 | from sklearn.preprocessing import MinMaxScaler |
32 | 34 | from sklearn.datasets import make_regression |
@@ -289,3 +291,93 @@ def test_seeding(data): |
289 | 291 | # clean up |
290 | 292 | if os.path.exists(POP_FILENAME): |
291 | 293 | os.remove(POP_FILENAME) |
| 294 | + |
| 295 | + |
| 296 | +def _compare_dicts(d1, d2, path=""): |
| 297 | + diffs = [] |
| 298 | + all_keys = set(d1.keys()) | set(d2.keys()) |
| 299 | + |
| 300 | + for key in all_keys: |
| 301 | + subpath = f"{path}.{key}" if path else key |
| 302 | + |
| 303 | + if key not in d1 or key not in d2: |
| 304 | + diffs.append((subpath, "Path exists in only one dict")) |
| 305 | + continue |
| 306 | + |
| 307 | + v1, v2 = d1[key], d2[key] |
| 308 | + |
| 309 | + if isinstance(v1, dict) and isinstance(v2, dict): |
| 310 | + diffs.extend(_compare_dicts(v1, v2, subpath)) |
| 311 | + elif isinstance(v1, list) and isinstance(v2, list): |
| 312 | + if len(v1) != len(v2): |
| 313 | + diffs.append((subpath, f"List length differs: {len(v1)} != {len(v2)}")) |
| 314 | + for i, (x, y) in enumerate(zip(v1, v2)): |
| 315 | + diffs.extend(_compare_dicts({0: x}, {0: y}, f"{subpath}[{i}]")) |
| 316 | + elif isinstance(v1, numbers.Real) and isinstance(v2, numbers.Real): |
| 317 | + if not np.isclose(v1, v2, atol=1e-10, rtol=0.0): |
| 318 | + diffs.append((subpath, f"{v1} != {v2}")) |
| 319 | + elif v1 != v2: |
| 320 | + diffs.append((subpath, f"{v1} != {v2}")) |
| 321 | + |
| 322 | + return diffs |
| 323 | + |
| 324 | + |
| 325 | +def _test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start): |
| 326 | + N = 500 |
| 327 | + DX = 3 |
| 328 | + X = np.random.random((N, DX)) |
| 329 | + y = np.random.randn(N, 1) |
| 330 | + |
| 331 | + xcs = xcsf.XCS(x_dim=DX, pop_size=5, max_trials=1000, pop_init=pop_init) |
| 332 | + xcs.fit(X, y, verbose=False) |
| 333 | + |
| 334 | + # Initial, “too large” population. |
| 335 | + json0 = xcs.json() |
| 336 | + pop0 = json.loads(json0) |
| 337 | + |
| 338 | + # “Pruning”. |
| 339 | + pop1 = deepcopy(pop0) |
| 340 | + del pop1["classifiers"][0] |
| 341 | + json1 = json.dumps(pop1) |
| 342 | + (tmp_path / "pset1.json").write_text(json1) |
| 343 | + |
| 344 | + if fitinbetween: |
| 345 | + xcs.fit(X, y, warm_start=True, verbose=False) |
| 346 | + |
| 347 | + xcs.json_read(str(tmp_path / "pset1.json"), clean=clean) |
| 348 | + |
| 349 | + # Pipe through `loads` b/c that was done above as well. |
| 350 | + json2 = json.dumps(json.loads(xcs.json())) |
| 351 | + |
| 352 | + list1 = json.loads(json1)["classifiers"] |
| 353 | + list2 = json.loads(json2)["classifiers"] |
| 354 | + |
| 355 | + if len(list1) != len(list2): |
| 356 | + return False |
| 357 | + else: |
| 358 | + unequal = False |
| 359 | + for cl1, cl2 in zip(list1, list2): |
| 360 | + # If there is any difference, … |
| 361 | + if _compare_dicts(cl1, cl2): |
| 362 | + unequal = True |
| 363 | + break |
| 364 | + return not unequal |
| 365 | + |
| 366 | + |
| 367 | +@pytest.mark.parametrize( |
| 368 | + "pop_init,clean,fitinbetween,warm_start", |
| 369 | + [ |
| 370 | + (False, True, False, False), |
| 371 | + (False, True, True, False), |
| 372 | + (False, True, True, True), |
| 373 | + (True, True, False, False), |
| 374 | + (True, True, True, False), |
| 375 | + (True, True, True, True), |
| 376 | + ], |
| 377 | +) |
| 378 | +def test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start): |
| 379 | + for seed in range(19): |
| 380 | + np.random.seed(seed) |
| 381 | + assert _test_pop_replace( |
| 382 | + tmp_path, pop_init, clean, fitinbetween, warm_start |
| 383 | + ), f"failed at seed {seed}" |
0 commit comments