Skip to content

Commit 59c03c5

Browse files
authored
Merge pull request #30 from floriangardin/master
add factorization of rules
2 parents 4ee939d + 81f08f3 commit 59c03c5

File tree

4 files changed

+131
-6
lines changed

4 files changed

+131
-6
lines changed

skrules/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .skope_rules import SkopeRules
2+
from .rule import Rule
23

3-
__all__ = ['SkopeRules']
4+
__all__ = ['SkopeRules', 'Rule']

skrules/rule.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
class Rule:
    """A logical rule stored in a factorized, canonical form.

    The rule is a conjunction of terms. Redundant terms on the same
    ``(feature, symbol)`` pair are merged into a single term, which makes
    it possible to simplify rules and to deduplicate them.

    Parameters
    ----------

    rule : str
        The logical rule that is interpretable by a pandas query. Terms
        must be joined by ``' and '`` and each term must be written as
        ``'feature symbol value'`` with single spaces as separators.

    args : object, optional
        Arguments associated to the rule. They are not used for
        factorization, equality or hashing, but they are part of the
        output when the rule is iterated over (e.g. converted to a tuple).

    Attributes
    ----------

    terms : list of list of str
        The raw ``[feature, symbol, value]`` triples parsed from the input.

    agg_dict : dict
        Maps each ``(feature, symbol)`` pair to its aggregated value,
        stored as a string.
    """

    def __init__(self, rule, args=None):
        self.rule = rule
        self.args = args
        self.terms = [t.split(' ') for t in self.rule.split(' and ')]
        self.agg_dict = {}
        self.factorize()
        # Replace the raw input by its canonical (factorized) string form.
        self.rule = str(self)

    def __eq__(self, other):
        # Let Python handle comparisons with foreign types instead of
        # failing with AttributeError on ``other.agg_dict``.
        if not isinstance(other, Rule):
            return NotImplemented
        return self.agg_dict == other.agg_dict

    def __ne__(self, other):
        # Explicit so that ``!=`` mirrors ``==`` on Python 2 as well, where
        # it is not derived from ``__eq__``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Order-independent and consistent with __eq__: equal dicts have
        # equal sets of items.
        return hash(frozenset(self.agg_dict.items()))

    def factorize(self):
        """Aggregate ``self.terms`` into ``self.agg_dict``.

        For a repeated ``(feature, symbol)`` pair, the smallest value is
        kept when the symbol starts with ``'<'`` and the largest one when
        it starts with ``'>'``; for any other symbol the last value wins.
        Values are stored as ``str(float(value))``, except for ``'=='``
        terms whose value is kept verbatim.
        """
        for feature, symbol, value in self.terms:
            key = (feature, symbol)
            if key not in self.agg_dict:
                if symbol != '==':
                    self.agg_dict[key] = str(float(value))
                else:
                    self.agg_dict[key] = value
            elif symbol[0] == '<':
                self.agg_dict[key] = str(min(
                    float(self.agg_dict[key]),
                    float(value)))
            elif symbol[0] == '>':
                self.agg_dict[key] = str(max(
                    float(self.agg_dict[key]),
                    float(value)))
            else:  # Handle the c0 == c0 case
                self.agg_dict[key] = value

    def __iter__(self):
        # Yields exactly two items so that ``tuple(rule)`` gives
        # ``(rule_string, args)``.
        yield str(self)
        yield self.args

    def __repr__(self):
        # Terms are sorted by (feature, symbol) so that the output does not
        # depend on the order of the terms in the input rule.
        return ' and '.join([' '.join(
            [feature, symbol, str(self.agg_dict[(feature, symbol)])])
            for feature, symbol in sorted(self.agg_dict.keys())
        ])

skrules/skope_rules.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from sklearn.externals import six
1313
from sklearn.tree import _tree
1414

15+
from .rule import Rule
16+
1517
INTEGER_TYPES = (numbers.Integral, np.integer)
1618

1719

@@ -205,7 +207,8 @@ def fit(self, X, y, sample_weight=None):
205207
" in the data, but the data contains only one"
206208
" class: %r" % self.classes_[0])
207209

208-
if not isinstance(self.max_depth_duplication, int) and self.max_depth_duplication is not None:
210+
if not isinstance(self.max_depth_duplication, int) \
211+
and self.max_depth_duplication is not None:
209212
raise ValueError("max_depth_duplication should be an integer"
210213
)
211214
if not set(self.classes_) == set([0, 1]):
@@ -265,7 +268,8 @@ def fit(self, X, y, sample_weight=None):
265268
max_features=self.max_samples_features,
266269
bootstrap=self.bootstrap,
267270
bootstrap_features=self.bootstrap_features,
268-
# oob_score=... XXX may be added if selection on tree perf needed.
271+
# oob_score=... XXX may be added
272+
# if selection on tree perf needed.
269273
# warm_start=... XXX may be added to increase computation perf.
270274
n_jobs=self.n_jobs,
271275
random_state=self.random_state,
@@ -281,7 +285,8 @@ def fit(self, X, y, sample_weight=None):
281285
max_features=self.max_samples_features,
282286
bootstrap=self.bootstrap,
283287
bootstrap_features=self.bootstrap_features,
284-
# oob_score=... XXX may be added if selection on tree perf needed.
288+
# oob_score=... XXX may be added
289+
# if selection on tree perf needed.
285290
# warm_start=... XXX may be added to increase computation perf.
286291
n_jobs=self.n_jobs,
287292
random_state=self.random_state,
@@ -345,6 +350,12 @@ def fit(self, X, y, sample_weight=None):
345350
for r in set(rules_from_tree)]
346351
rules_ += rules_from_tree
347352

353+
# Factorize rules before semantic tree filtering
354+
rules_ = [
355+
tuple(rule)
356+
for rule in
357+
[Rule(r, args=args) for r, args in rules_]]
358+
348359
# keep only rules verifying precision_min and recall_min:
349360
for rule, score in rules_:
350361
if score[0] >= self.precision_min and score[1] >= self.recall_min:
@@ -363,7 +374,7 @@ def fit(self, X, y, sample_weight=None):
363374
self.rules_ = sorted(self.rules_.items(),
364375
key=lambda x: (x[1][0], x[1][1]), reverse=True)
365376

366-
# count representation of feature
377+
# Deduplicate the rule using semantic tree
367378
if self.max_depth_duplication is not None:
368379
self.rules_ = self.deduplicate(self.rules_)
369380
return self
@@ -576,7 +587,7 @@ def recurse(node, base_name):
576587
else:
577588
rule = str.join(' and ', base_name)
578589
rule = (rule if rule != ''
579-
else '=='.join([feature_names[0]] * 2))
590+
else ' == '.join([feature_names[0]] * 2))
580591
# a rule selecting all is set to "c0 == c0"
581592
rules.append(rule)
582593

skrules/tests/test_rule.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from sklearn.utils.testing import assert_equal, assert_not_equal
2+
3+
from skrules import Rule
4+
5+
6+
def test_rule():
    """Redundant terms are factorized away, for equality and for hashing."""
    # Each pair holds a redundant rule and its expected minimal equivalent.
    equivalent_pairs = [
        ('a <= 10 and a <= 12',
         'a <= 10'),
        ('a <= 10 and a <= 12 and a > 3',
         'a > 3 and a <= 10'),
        ('a <= 10 and a <= 10 and a > 3',
         'a > 3 and a <= 10'),
        ('a <= 10 and a <= 12 and b > 3 and b > 6',
         'a <= 10 and b > 6'),
    ]
    for redundant, minimal in equivalent_pairs:
        assert_equal(Rule(redundant), Rule(minimal))

    # Each pair must collapse to a single element once put in a set.
    collapsing_pairs = [
        ('a <= 2 and a <= 3',
         'a <= 2'),
        ('a > 2 and a > 3 and b <= 2 and b <= 3',
         'a > 3 and b <= 2'),
        ('a <= 3 and b <= 2',
         'b <= 2 and a <= 3'),
    ]
    for first, second in collapsing_pairs:
        assert_equal(len({Rule(first), Rule(second)}), 1)
29+
30+
31+
def test_hash_rule():
    """Equivalent rules share a set slot; distinct rules do not."""
    equivalent = {Rule('a <= 2 and a <= 3'), Rule('a <= 2')}
    assert_equal(len(equivalent), 1)

    distinct = {Rule('a <= 4 and a <= 3'), Rule('a <= 2')}
    assert_not_equal(len(distinct), 1)
40+
41+
42+
def test_str_rule():
    """A rule already in canonical form is rendered back unchanged."""
    canonical = 'a <= 10.0 and b > 3.0'
    rendered = str(Rule(canonical))
    assert_equal(canonical, rendered)
45+
46+
47+
def test_equals_rule():
    """The '==' terms (select-all rule) survive factorization verbatim."""
    select_all = "a == a"
    mixed = "a < 3.0 and a == a"

    # (expected rendering, input rule) pairs, checked in order.
    cases = (
        (select_all, select_all),
        (select_all, "a == a and a == a"),
        (mixed, mixed),
    )
    for expected, text in cases:
        assert_equal(expected, str(Rule(text)))

0 commit comments

Comments
 (0)