Skip to content

Commit 0ad0677

Browse files
author
Florian Gardin
committed
fix pep8
1 parent 75091e3 commit 0ad0677

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

skrules/skope_rules.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
INTEGER_TYPES = (numbers.Integral, np.integer)
1616

17+
1718
class SkopeRules(BaseEstimator):
1819
""" An easy-interpretable classifier optimizing simple logical rules.
1920
@@ -158,7 +159,8 @@ def __init__(self,
158159
self.bootstrap = bootstrap
159160
self.bootstrap_features = bootstrap_features
160161
self.max_depth = max_depth
161-
self.max_depths = max_depth if isinstance(max_depth, Iterable) else [max_depth]
162+
self.max_depths = max_depth \
163+
if isinstance(max_depth, Iterable) else [max_depth]
162164
self.max_depth_duplication = max_depth_duplication
163165
self.max_features = max_features
164166
self.min_samples_split = min_samples_split
@@ -361,7 +363,6 @@ def fit(self, X, y, sample_weight=None):
361363
# count representation of feature
362364
if self.max_depth_duplication is not None:
363365
self.rules_ = self.deduplicate(self.rules_)
364-
# TODO : Factorize disjoints performing rules (ex : c0 > 0 and c1 > 1 , c0 > 0 and c1 <= 1)
365366
return self
366367

367368
def predict(self, X):
@@ -592,10 +593,12 @@ def _eval_rule_perf(self, rule, X, y):
592593
return y_detected.mean(), float(true_pos) / pos
593594

594595
def deduplicate(self, rules):
595-
return [max(rules_set, key=self.f1_score) for rules_set in self._find_similar_rulesets(rules)]
596+
return [max(rules_set, key=self.f1_score)
597+
for rules_set in self._find_similar_rulesets(rules)]
596598

597599
def _find_similar_rulesets(self, rules):
598-
"""Create clusters of rules using a decision tree based on the terms of the rules
600+
"""Create clusters of rules using a decision tree based
601+
on the terms of the rules
599602
600603
Parameters
601604
----------
@@ -635,11 +638,11 @@ def split_with_best_feature(rules, depth, exceptions=[]):
635638
rules_splitted[1].append(rule)
636639
else:
637640
rules_splitted[2].append(rule)
638-
641+
new_exceptions = exceptions+[most_represented_term]
639642
# Choose best term
640643
return [split_with_best_feature(ruleset,
641644
depth-1,
642-
exceptions=exceptions+[most_represented_term])
645+
exceptions=new_exceptions)
643646
for ruleset in rules_splitted]
644647

645648
def breadth_first_search(rules, leaves=None):
@@ -656,4 +659,5 @@ def breadth_first_search(rules, leaves=None):
656659
return leaves
657660

658661
def f1_score(self, x):
659-
return 2 * x[1][0] * x[1][1] / (x[1][0] + x[1][1]) if (x[1][0] + x[1][1]) > 0 else 0
662+
return 2 * x[1][0] * x[1][1] / \
663+
(x[1][0] + x[1][1]) if (x[1][0] + x[1][1]) > 0 else 0

skrules/tests/test_skope_rules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_skope_rules_works():
132132
rules_vote = clf.rules_vote(X_test)
133133
score_top_rules = clf.score_top_rules(X_test)
134134
pred = clf.predict(X_test)
135-
pred_score_top_rules = clf.predict_top_rules(X_test,1)
135+
pred_score_top_rules = clf.predict_top_rules(X_test, 1)
136136
# assert detect outliers:
137137
assert_greater(np.min(decision_func[-2:]), np.max(decision_func[:-2]))
138138
assert_greater(np.min(rules_vote[-2:]), np.max(rules_vote[:-2]))
@@ -141,6 +141,7 @@ def test_skope_rules_works():
141141
assert_array_equal(pred, 6 * [0] + 2 * [1])
142142
assert_array_equal(pred_score_top_rules, 6 * [0] + 2 * [1])
143143

144+
144145
def test_deduplication_works():
145146
# toy sample (the last two samples are outliers)
146147
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [4, -7]]
@@ -154,7 +155,7 @@ def test_deduplication_works():
154155
rules_vote = clf.rules_vote(X_test)
155156
score_top_rules = clf.score_top_rules(X_test)
156157
pred = clf.predict(X_test)
157-
pred_score_top_rules = clf.predict_top_rules(X_test,1)
158+
pred_score_top_rules = clf.predict_top_rules(X_test, 1)
158159

159160

160161
def test_performances():

0 commit comments

Comments
 (0)