forked from biolab/orange3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_tree.py
More file actions
125 lines (104 loc) · 4.14 KB
/
test_tree.py
File metadata and controls
125 lines (104 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import unittest
from unittest.mock import Mock
import numpy as np
import sklearn.tree as skl_tree
from sklearn.tree._tree import TREE_LEAF
from Orange.data import Table
from Orange.classification import SklTreeLearner, TreeLearner
from Orange.regression import SklTreeRegressionLearner
class TestSklTreeLearner(unittest.TestCase):
    """Fit SklTreeLearner / SklTreeRegressionLearner and predict the training rows."""

    def test_classification(self):
        # Train on iris, then predict the very same rows: the test expects
        # every training label to be reproduced.
        data = Table('iris')
        model = SklTreeLearner()(data)
        predicted = model(data)
        expected = data.Y.flatten()
        self.assertTrue(np.all(expected == predicted))

    def test_regression(self):
        # Same round trip on the housing data: predictions on the training
        # rows are expected to equal the stored target values exactly.
        data = Table('housing')
        model = SklTreeRegressionLearner()(data)
        predicted = model(data)
        expected = data.Y.flatten()
        self.assertTrue(np.all(expected == predicted))
class TestTreeLearner(unittest.TestCase):
    """Tests for TreeLearner's handling of the ``preprocessors`` argument."""

    def test_uses_preprocessors(self):
        data = Table('iris')

        # The mock returns the table unchanged, so the learner is fitted on
        # iris itself while the mock records how it was called.
        preprocessor = Mock()
        preprocessor.return_value = data

        learner = TreeLearner(preprocessors=[preprocessor])
        learner(data)

        preprocessor.assert_called_with(data)
class TestDecisionTreeClassifier(unittest.TestCase):
    """Checks of ``skl_tree.DecisionTreeClassifier`` fitted on the iris table.

    Each test fits a classifier with one constructor argument set and then
    inspects either the predictions or the arrays of the fitted ``tree_``.
    """

    @classmethod
    def setUpClass(cls):
        # Load the table once for the whole class; the tests only read it.
        cls.iris = Table('iris')

    def test_full_tree(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier()
        clf = clf.fit(table.X, table.Y)
        predicted = clf.predict(table.X)
        # A tree grown without limits is expected to reproduce every
        # training label.
        self.assertTrue(np.all(table.Y.flatten() == predicted))

    def test_min_samples_split(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(min_samples_split=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        for i in range(t.node_count):
            # Only nodes that were actually split are subject to the limit.
            if t.children_left[i] != TREE_LEAF:
                self.assertGreaterEqual(t.n_node_samples[i], lim)

    def test_min_samples_leaf(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(min_samples_leaf=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        for i in range(t.node_count):
            # Only leaves are subject to the limit.
            if t.children_left[i] == TREE_LEAF:
                self.assertGreaterEqual(t.n_node_samples[i], lim)

    def test_max_leaf_nodes(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(max_leaf_nodes=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        # A binary tree with at most `lim` leaves has at most 2 * lim - 1
        # nodes in total.
        self.assertLessEqual(t.node_count, lim * 2 - 1)

    def test_criterion(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(criterion="entropy")
        clf = clf.fit(table.X, table.Y)
        # Fitting must not raise and must produce at least one split.
        self.assertGreater(clf.tree_.node_count, 1)

    def test_splitter(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(splitter="random")
        clf = clf.fit(table.X, table.Y)
        # Fitting must not raise and must produce at least one split.
        self.assertGreater(clf.tree_.node_count, 1)

    def test_weights(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(max_depth=2)
        clf = clf.fit(table.X, table.Y)
        clfw = skl_tree.DecisionTreeClassifier(max_depth=2)
        clfw = clfw.fit(table.X, table.Y, sample_weight=np.arange(len(table)))
        # Weighting the rows by their index is expected to change which
        # features the tree splits on; array_equal is False for arrays of
        # different length as well as for differing elements.
        self.assertFalse(
            np.array_equal(clf.tree_.feature, clfw.tree_.feature))

    def test_impurity(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier()
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        for i in range(t.node_count):
            if t.children_left[i] == TREE_LEAF:
                # Leaves of the unrestricted tree are expected to be pure.
                self.assertEqual(t.impurity[i], 0)
            else:
                # At least one child must be no more impure than its parent.
                left, right = t.children_left[i], t.children_right[i]
                child_impurity = min(t.impurity[left], t.impurity[right])
                self.assertLessEqual(child_impurity, t.impurity[i])

    def test_navigate_tree(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(max_depth=1)
        clf = clf.fit(table.X, table.Y.reshape(-1, 1))
        t = clf.tree_
        x = table.X[0]
        # Node 0 is the root; follow its split by hand for the first row.
        if x[t.feature[0]] <= t.threshold[0]:
            v = t.value[t.children_left[0]][0]
        else:
            v = t.value[t.children_right[0]][0]
        # predict() returns an array even for a single row; take its only
        # element so that two scalars are compared instead of relying on
        # the truth value of a one-element array.
        # NOTE(review): comparing an argmax index with a predicted label
        # relies on the class labels coinciding with the class indices.
        predicted = clf.predict(table.X[:1])[0]
        self.assertEqual(np.argmax(v), predicted)