Skip to content

Commit ea6e3e1

Browse files
committed
refactoring
1 parent 2c070c9 commit ea6e3e1

14 files changed

+340
-272
lines changed

_doc/examples/plot_piecewise_linear_regression_criterion.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -152,25 +152,27 @@
152152
#
153153
# ::
154154
#
155-
# cdef void _mean(self, SIZE_t start, SIZE_t end, DOUBLE_t *mean,
156-
# DOUBLE_t *weight) nogil:
155+
# ctypedef double float64_t
156+
#
157+
# cdef void _mean(self, SIZE_t start, SIZE_t end, float64_t *mean,
158+
# float64_t *weight) nogil:
157159
# if start == end:
158160
# mean[0] = 0.
159161
# return
160-
# cdef DOUBLE_t m = 0.
161-
# cdef DOUBLE_t w = 0.
162+
# cdef float64_t m = 0.
163+
# cdef float64_t w = 0.
162164
# cdef int k
163165
# for k in range(start, end):
164166
# m += self.sample_wy[k]
165167
# w += self.sample_w[k]
166168
# weight[0] = w
167169
# mean[0] = 0. if w == 0. else m / w
168170
#
169-
# cdef double _mse(self, SIZE_t start, SIZE_t end, DOUBLE_t mean,
170-
# DOUBLE_t weight) nogil:
171+
# cdef float64_t _mse(self, SIZE_t start, SIZE_t end, float64_t mean,
172+
# float64_t weight) nogil:
171173
# if start == end:
172174
# return 0.
173-
# cdef DOUBLE_t squ = 0.
175+
# cdef float64_t squ = 0.
174176
# cdef int k
175177
# for k in range(start, end):
176178
# squ += (self.y[self.sample_i[k], 0] - mean) ** 2 * self.sample_w[k]
@@ -189,24 +191,26 @@
189191
#
190192
# ::
191193
#
192-
# cdef void _mean(self, SIZE_t start, SIZE_t end, DOUBLE_t *mean,
193-
# DOUBLE_t *weight) nogil:
194+
# ctypedef double float64_t
195+
#
196+
# cdef void _mean(self, SIZE_t start, SIZE_t end, float64_t *mean,
197+
# float64_t *weight) nogil:
194198
# if start == end:
195199
# mean[0] = 0.
196200
# return
197-
# cdef DOUBLE_t m = self.sample_wy_left[end-1] -
198-
# (self.sample_wy_left[start-1] if start > 0 else 0)
199-
# cdef DOUBLE_t w = self.sample_w_left[end-1] -
200-
# (self.sample_w_left[start-1] if start > 0 else 0)
201+
# cdef float64_t m = self.sample_wy_left[end-1] -
202+
# (self.sample_wy_left[start-1] if start > 0 else 0)
203+
# cdef float64_t w = self.sample_w_left[end-1] -
204+
# (self.sample_w_left[start-1] if start > 0 else 0)
201205
# weight[0] = w
202206
# mean[0] = 0. if w == 0. else m / w
203207
#
204-
# cdef double _mse(self, SIZE_t start, SIZE_t end, DOUBLE_t mean,
205-
# DOUBLE_t weight) nogil:
208+
# cdef float64_t _mse(self, SIZE_t start, SIZE_t end, float64_t mean,
209+
# float64_t weight) nogil:
206210
# if start == end:
207211
# return 0.
208-
# cdef DOUBLE_t squ = self.sample_wy2_left[end-1] -
209-
# (self.sample_wy2_left[start-1] if start > 0 else 0)
212+
# cdef float64_t squ = self.sample_wy2_left[end-1] -
213+
# (self.sample_wy2_left[start-1] if start > 0 else 0)
210214
# # This formula only holds if mean is computed on the same interval.
211215
# # Otherwise, it is squ / weight - true_mean ** 2 + (mean - true_mean) ** 2.
212216
# return 0. if weight == 0. else squ / weight - mean ** 2

_unittests/ut_mlmodel/test_piecewise_decision_tree_experiment.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ def test_criterions(self):
2121
with warnings.catch_warnings(record=True) as w:
2222
warnings.simplefilter("always")
2323
from mlinsights.mlmodel._piecewise_tree_regression_common import (
24-
_test_criterion_check,
2524
assert_criterion_equal,
26-
)
27-
from mlinsights.mlmodel._piecewise_tree_regression_common import (
25+
_test_criterion_check,
2826
_test_criterion_init,
2927
_test_criterion_node_impurity,
3028
_test_criterion_node_impurity_children,
@@ -38,10 +36,6 @@ def test_criterions(self):
3836
SimpleRegressorCriterion,
3937
)
4038

41-
if len(w) > 0:
42-
msg = "\n".join(map(str, w))
43-
raise AssertionError(f"Warning while importing the library:\n{msg}")
44-
4539
X = numpy.array([[1.0, 2.0]]).T
4640
y = numpy.array([1.0, 2.0])
4741
c1 = MSE(1, X.shape[0])
@@ -113,6 +107,8 @@ def test_criterions(self):
113107
assert_criterion_equal(c1, c2)
114108
self.assertTrue(numpy.isnan(p1), numpy.isnan(p2))
115109

110+
expected_p2 = [-0.56, -0.04, -0.56]
111+
116112
for i in range(1, 4):
117113
_test_criterion_check(c2)
118114
_test_criterion_update(c1, i)
@@ -122,23 +118,27 @@ def test_criterions(self):
122118
self.assertIsInstance(c2.printd(), str)
123119
left1, right1 = _test_criterion_node_impurity_children(c1)
124120
left2, right2 = _test_criterion_node_impurity_children(c2)
125-
self.assertAlmostEqual(left1, left2)
121+
self.assertAlmostEqual(left1, left2, atol=1e-10)
126122
self.assertAlmostEqual(right1, right2, atol=1e-10)
127123
v1 = _test_criterion_node_value(c1)
128124
v2 = _test_criterion_node_value(c2)
129125
self.assertEqual(v1, v2)
130126
p1 = _test_criterion_impurity_improvement(c1, 0.0, left1, right1)
131127
p2 = _test_criterion_impurity_improvement(c2, 0.0, left2, right2)
132-
self.assertIn(
133-
"value: 1.500000 total=0.260000 left=0.000000 right=0.186667",
134-
_test_criterion_printf(c1),
135-
)
136-
self.assertIn(
137-
"value: 1.500000 total=0.260000 left=0.000000 right=0.186667",
138-
_test_criterion_printf(c2),
139-
)
128+
if i == 1:
129+
self.assertIn(
130+
"value: 1.500000 total=0.260000 left=0.000000 right=0.186667",
131+
_test_criterion_printf(c1),
132+
)
133+
self.assertIn(
134+
"value: 1.500000 total=0.260000 left=0.000000 right=0.186667",
135+
_test_criterion_printf(c2),
136+
)
140137
self.assertEqual(_test_criterion_printf(c1), _test_criterion_printf(c2))
141-
self.assertAlmostEqual(p1, p2, atol=1e-10)
138+
self.assertInAlmostEqual(
139+
p1, (0, p2), atol=1e-10
140+
) # 0 if the function is not called
141+
self.assertAlmostEqual(expected_p2[i - 1], p2, atol=1e-10)
142142

143143
X = numpy.array([[1.0, 2.0, 10.0, 11.0]]).T
144144
y = numpy.array([0.9, 1.1, 1.9, 2.1])
@@ -159,37 +159,62 @@ def test_criterions(self):
159159
p2 = _test_criterion_proxy_impurity_improvement(c2)
160160
self.assertTrue(numpy.isnan(p1), numpy.isnan(p2))
161161

162+
expected_p2 = [-0.32, -0.02]
163+
162164
for i in range(2, 4):
163165
_test_criterion_update(c1, i)
164166
_test_criterion_update(c2, i)
165167
left1, right1 = _test_criterion_node_impurity_children(c1)
166168
left2, right2 = _test_criterion_node_impurity_children(c2)
167-
self.assertAlmostEqual(left1, left2)
168-
self.assertAlmostEqual(right1, right2)
169+
self.assertAlmostEqual(left1, left2, atol=1e-10)
170+
self.assertAlmostEqual(right1, right2, atol=1e-10)
169171
v1 = _test_criterion_node_value(c1)
170172
v2 = _test_criterion_node_value(c2)
171173
self.assertEqual(v1, v2)
172174
p1 = _test_criterion_impurity_improvement(c1, 0.0, left1, right1)
173175
p2 = _test_criterion_impurity_improvement(c2, 0.0, left2, right2)
174-
self.assertAlmostEqual(p1, p2)
176+
self.assertInAlmostEqual(
177+
p1, (0, p2), atol=1e-10
178+
) # 0 if the function is not called
179+
self.assertAlmostEqual(expected_p2[i - 2], p2, atol=1e-10)
175180

176181
def test_decision_tree_criterion(self):
177182
from mlinsights.mlmodel.piecewise_tree_regression_criterion import (
178183
SimpleRegressorCriterion,
179184
)
180185

186+
debug = __name__ == "__main__"
187+
181188
X = numpy.array([[1.0, 2.0, 10.0, 11.0]]).T
182189
y = numpy.array([0.9, 1.1, 1.9, 2.1])
190+
if debug:
191+
print("create the tree")
183192
clr1 = DecisionTreeRegressor(max_depth=1)
193+
if debug:
194+
print("train the tree")
184195
clr1.fit(X, y)
196+
if debug:
197+
print("predict with the tree")
185198
p1 = clr1.predict(X)
199+
if debug:
200+
print(f"done {p1}")
186201

202+
if debug:
203+
print("create the criterion")
187204
crit = SimpleRegressorCriterion(
188205
1 if len(y.shape) <= 1 else y.shape[1], X.shape[0]
189206
)
207+
if debug:
208+
print("create the new tree")
190209
clr2 = DecisionTreeRegressor(criterion=crit, max_depth=1)
210+
if debug:
211+
print("train the new tree")
191212
clr2.fit(X, y)
213+
if debug:
214+
print("predict with the new tree")
192215
p2 = clr2.predict(X)
216+
if debug:
217+
print(f"done {p2}")
193218
self.assertEqual(p1, p2)
194219
self.assertEqual(clr1.tree_.node_count, clr2.tree_.node_count)
195220

_unittests/ut_mlmodel/test_piecewise_decision_tree_experiment_fast.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,37 @@
11
# -*- coding: utf-8 -*-
22
import unittest
3+
import warnings
34
import numpy
45
from sklearn.tree._criterion import MSE
56
from sklearn.tree import DecisionTreeRegressor
67
from sklearn import datasets
78
from mlinsights.ext_test_case import ExtTestCase
89
from mlinsights.mlmodel.piecewise_tree_regression import PiecewiseTreeRegressor
9-
from mlinsights.mlmodel._piecewise_tree_regression_common import (
10-
_test_criterion_init,
11-
_test_criterion_node_impurity,
12-
_test_criterion_node_impurity_children,
13-
_test_criterion_update,
14-
_test_criterion_node_value,
15-
_test_criterion_proxy_impurity_improvement,
16-
_test_criterion_impurity_improvement,
17-
)
18-
from mlinsights.mlmodel._piecewise_tree_regression_common import (
19-
assert_criterion_equal,
20-
)
21-
from mlinsights.mlmodel.piecewise_tree_regression_criterion_fast import (
22-
SimpleRegressorCriterionFast,
23-
)
10+
11+
with warnings.catch_warnings(record=True) as w:
12+
warnings.simplefilter("always")
13+
from mlinsights.mlmodel._piecewise_tree_regression_common import (
14+
_test_criterion_init,
15+
_test_criterion_node_impurity,
16+
_test_criterion_node_impurity_children,
17+
_test_criterion_update,
18+
_test_criterion_node_value,
19+
_test_criterion_proxy_impurity_improvement,
20+
_test_criterion_impurity_improvement,
21+
)
22+
from mlinsights.mlmodel._piecewise_tree_regression_common import (
23+
assert_criterion_equal,
24+
)
25+
from mlinsights.mlmodel.piecewise_tree_regression_criterion_fast import (
26+
SimpleRegressorCriterionFast,
27+
)
2428

2529

2630
class TestPiecewiseDecisionTreeExperimentFast(ExtTestCase):
27-
@unittest.skip(
28-
reason="self.y = y raises: Fatal Python error: "
29-
"__pyx_fatalerror: Acquisition count is"
30-
)
31+
# @unittest.skip(
32+
# reason="self.y = y raises: Fatal Python error: "
33+
# "__pyx_fatalerror: Acquisition count is"
34+
# )
3135
def test_criterions(self):
3236
X = numpy.array([[1.0, 2.0]]).T
3337
y = numpy.array([1.0, 2.0])

_unittests/ut_mlmodel/test_piecewise_decision_tree_experiment_linear.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
# -*- coding: utf-8 -*-
22
import unittest
3+
import warnings
34
import numpy
45
from sklearn.tree._criterion import MSE
56
from sklearn.tree import DecisionTreeRegressor
67
from sklearn import datasets
78
from sklearn.model_selection import train_test_split
89
from mlinsights.ext_test_case import ExtTestCase
910
from mlinsights.mlmodel.piecewise_tree_regression import PiecewiseTreeRegressor
10-
from mlinsights.mlmodel._piecewise_tree_regression_common import (
11-
_test_criterion_init,
12-
_test_criterion_node_impurity,
13-
_test_criterion_node_impurity_children,
14-
_test_criterion_update,
15-
_test_criterion_node_value,
16-
_test_criterion_proxy_impurity_improvement,
17-
_test_criterion_impurity_improvement,
18-
)
19-
from mlinsights.mlmodel.piecewise_tree_regression_criterion_linear import (
20-
LinearRegressorCriterion,
21-
)
11+
12+
with warnings.catch_warnings(record=True) as w:
13+
warnings.simplefilter("always")
14+
from mlinsights.mlmodel._piecewise_tree_regression_common import (
15+
_test_criterion_init,
16+
_test_criterion_node_impurity,
17+
_test_criterion_node_impurity_children,
18+
_test_criterion_update,
19+
_test_criterion_node_value,
20+
_test_criterion_proxy_impurity_improvement,
21+
_test_criterion_impurity_improvement,
22+
)
23+
from mlinsights.mlmodel.piecewise_tree_regression_criterion_linear import (
24+
LinearRegressorCriterion,
25+
)
2226

2327

2428
class TestPiecewiseDecisionTreeExperimentLinear(ExtTestCase):
25-
@unittest.skip(
26-
reason="self.y = y raises: Fatal Python error: "
27-
"__pyx_fatalerror: Acquisition count is"
28-
)
29+
# @unittest.skip(
30+
# reason="self.y = y raises: Fatal Python error: "
31+
# "__pyx_fatalerror: Acquisition count is"
32+
# )
2933
def test_criterions(self):
3034
X = numpy.array([[10.0, 12.0, 13.0]]).T
3135
y = numpy.array([20.0, 22.0, 23.0])
@@ -127,10 +131,10 @@ def test_criterions(self):
127131
self.assertGreater(dest[0], 0)
128132
self.assertGreater(dest[1], 0)
129133

130-
@unittest.skip(
131-
reason="self.y = y raises: Fatal Python error: "
132-
"__pyx_fatalerror: Acquisition count is"
133-
)
134+
# @unittest.skip(
135+
# reason="self.y = y raises: Fatal Python error: "
136+
# "__pyx_fatalerror: Acquisition count is"
137+
# )
134138
def test_criterions_check_value(self):
135139
X = numpy.array([[10.0, 12.0, 13.0]]).T
136140
y = numpy.array([[20.0, 22.0, 23.0]]).T
@@ -164,10 +168,10 @@ def test_decision_tree_criterion_iris(self):
164168
p2 = clr2.predict(X)
165169
self.assertEqual(p1.shape, p2.shape)
166170

167-
@unittest.skip(
168-
reason="self.y = y raises: Fatal Python error: "
169-
"__pyx_fatalerror: Acquisition count is"
170-
)
171+
# @unittest.skip(
172+
# reason="self.y = y raises: Fatal Python error: "
173+
# "__pyx_fatalerror: Acquisition count is"
174+
# )
171175
def test_decision_tree_criterion_iris_dtc(self):
172176
iris = datasets.load_iris()
173177
X, y = iris.data, iris.target
@@ -191,10 +195,10 @@ def test_decision_tree_criterion_iris_dtc(self):
191195
self.assertIsInstance(mp, dict)
192196
self.assertGreater(len(mp), 2)
193197

194-
@unittest.skip(
195-
reason="self.y = y raises: Fatal Python error: "
196-
"__pyx_fatalerror: Acquisition count is"
197-
)
198+
# @unittest.skip(
199+
# reason="self.y = y raises: Fatal Python error: "
200+
# "__pyx_fatalerror: Acquisition count is"
201+
# )
198202
def test_decision_tree_criterion_iris_dtc_traintest(self):
199203
iris = datasets.load_iris()
200204
X, y = iris.data, iris.target

mlinsights/ext_test_case.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import redirect_stderr, redirect_stdout
1010
from io import StringIO
1111
from timeit import Timer
12-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
1313
import numpy
1414
from numpy.testing import assert_allclose
1515
import pandas
@@ -231,6 +231,25 @@ def assertEqualDataFrame(self, d1, d2, **kwargs):
231231

232232
assert_frame_equal(d1, d2, **kwargs)
233233

234+
def assertInAlmostEqual(
235+
self,
236+
value: float,
237+
expected_values: Sequence[float],
238+
atol: float = 0,
239+
rtol: float = 0,
240+
):
241+
last_e = None
242+
for s in expected_values:
243+
try:
244+
self.assertAlmostEqual(value, s, atol=atol, rtol=rtol)
245+
return
246+
except AssertionError as e:
247+
last_e = e
248+
if last_e is not None:
249+
raise AssertionError(
250+
f"Value {value} not in set {expected_values}."
251+
) from last_e
252+
234253
def assertAlmostEqual(
235254
self,
236255
expected: numpy.ndarray,

0 commit comments

Comments
 (0)