Commit e80cb94

Author: Florian Gardin (committed)
Commit message: add new deduplication algorithm
1 parent 5f4a0f4 · commit e80cb94
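
In short, the commit replaces the overlap-based rule filtering in fit with a term-clustering deduplication, exposed through a new max_depth_duplication constructor parameter. A hypothetical usage sketch (the import path and the data X, y are assumptions for illustration, not part of this commit):

    from skrules import SkopeRules  # assumed package entry point

    # max_depth_duplication sets the depth of the term-clustering tree
    # used to deduplicate rules; None skips deduplication entirely
    # (see the fit() hunk below).
    clf = SkopeRules(max_depth_duplication=2)
    clf.fit(X, y)
    print(clf.rules_)  # one best rule (by F1 score) per cluster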

File tree

2 files changed: +119 −31 lines changed

skrules/skope_rules.py
skrules/tests/test_skope_rules.py

skrules/skope_rules.py

Lines changed: 74 additions & 31 deletions
@@ -1,15 +1,14 @@
 import numpy as np
+from collections import Counter
 import pandas
 import numbers
 from warnings import warn
+
 from sklearn.base import BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import check_classification_targets
-
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-
 from sklearn.ensemble import BaggingClassifier, BaggingRegressor
-
 from sklearn.externals import six
 from sklearn.tree import _tree

@@ -66,6 +65,9 @@ class SkopeRules(BaseEstimator):
         expanded until all leaves are pure or until all leaves contain less
         than min_samples_split samples.
 
+    max_depth_duplication : integer or None, optional (default=3)
+        The maximum depth of the decision tree used for rule deduplication.
+
     max_features : int, float, string or None, optional (default="auto")
         The number of features considered (by each decision tree) when looking
         for the best split:
@@ -144,6 +146,7 @@ def __init__(self,
                  bootstrap=False,
                  bootstrap_features=False,
                  max_depth=3,
+                 max_depth_duplication=3,
                  max_features=1.,
                  min_samples_split=2,
                  n_jobs=1,
@@ -159,6 +162,7 @@ def __init__(self,
         self.bootstrap = bootstrap
         self.bootstrap_features = bootstrap_features
         self.max_depth = max_depth
+        self.max_depth_duplication = max_depth_duplication
         self.max_features = max_features
         self.min_samples_split = min_samples_split
         self.n_jobs = n_jobs
@@ -357,34 +361,10 @@ def fit(self, X, y, sample_weight=None):
         self.rules_ = sorted(self.rules_.items(),
                              key=lambda x: (x[1][0], x[1][1]), reverse=True)
 
-        # removing rules which have very similar domains
-        X_ = pandas.DataFrame(X, columns=np.array(self.feature_names_))
-        omit_these_rules_list = []
-        perimeter_index_of_all_rules = []
-        for i in range(len(self.rules_)):
-            current = self.rules_[i]
-            perimeter_index_of_all_rules.append(
-                set(list(X_.query(current[0]).index))
-            )
-            index_current = perimeter_index_of_all_rules[i]
-
-            for j in range(i):
-                if j in omit_these_rules_list:
-                    # a rule that has already been discarded
-                    # should not be processed again
-                    continue
-
-                index_rival = perimeter_index_of_all_rules[j]
-                size_union = len(index_rival.union(index_current))
-                size_intersection = len(
-                    index_rival.intersection(index_current))
-
-                if float(size_intersection) / size_union > self.similarity_thres:
-                    omit_these_rules_list.append(j)
-
-        self.rules_ = [self.rules_[i] for i in range(
-            len(self.rules_)) if i not in omit_these_rules_list]
-
+        # deduplicate rules with similar terms, keeping the best of each cluster
+        if self.max_depth_duplication is not None:
+            self.rules_ = self.deduplicate(self.rules_)
+        # TODO: factorize disjoint well-performing rules (e.g. "c0 > 0 and c1 > 1", "c0 > 0 and c1 <= 1")
         return self
 
     def predict(self, X):
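
For contrast: the deleted block above judged similarity on the data itself, comparing the sets of samples covered by two rules with the Jaccard index and discarding one rule of any pair whose overlap exceeded similarity_thres. Schematically (a restatement of the removed lines, with the same variable names):

    # old criterion: Jaccard overlap of the index sets covered by two rules
    size_union = len(index_rival.union(index_current))
    size_intersection = len(index_rival.intersection(index_current))
    is_duplicate = float(size_intersection) / size_union > self.similarity_thres

The new deduplicate() path (next hunk) never touches X: it clusters rules from their textual terms alone and keeps the best rule of each cluster.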
@@ -613,3 +593,66 @@ def _eval_rule_perf(self, rule, X, y):
             return (0, 0)
         pos = y[y > 0].sum()
         return y_detected.mean(), float(true_pos) / pos
+
+    def deduplicate(self, rules):
+        # from each cluster of similar rules, keep only the rule
+        # with the best F1 score
+        return [max(rules_set, key=self.f1_score)
+                for rules_set in self._find_similar_rulesets(rules)]
+
+    def _find_similar_rulesets(self, rules):
+        """Create clusters of rules, using a decision tree built
+        on the terms of the rules.
+
+        Parameters
+        ----------
+        rules : list of rules
+
+        Returns
+        -------
+        rules : list of lists of rules
+        """
+        def split_with_best_feature(rules, depth, exceptions=[]):
+            """Split the rules on their most represented feature."""
+            if depth == 0:
+                return rules
+
+            rulelist = [rule.split(' and ') for rule, score in rules]
+            terms = [t.split(' ')[0] for rule_terms in rulelist
+                     for t in rule_terms]
+            counter = Counter(terms)
+            # drop the features already used for a split
+            for exception in exceptions:
+                del counter[exception]
+
+            if len(counter) == 0:
+                return rules
+
+            most_represented_term = counter.most_common()[0][0]
+            # split the rules into three subsets: "term <=",
+            # "term >" and rules not using the term at all
+            rules_splitted = [[], [], []]
+            for rule in rules:
+                if (most_represented_term + ' <=') in rule[0]:
+                    rules_splitted[0].append(rule)
+                elif (most_represented_term + ' >') in rule[0]:
+                    rules_splitted[1].append(rule)
+                else:
+                    rules_splitted[2].append(rule)
+
+            # recurse on each subset, excluding the term just used
+            new_exceptions = exceptions + [most_represented_term]
+            return [split_with_best_feature(ruleset, depth - 1,
+                                            exceptions=new_exceptions)
+                    for ruleset in rules_splitted]
+
+        def breadth_first_search(rules, leaves=None):
+            # collect the non-empty leaves of the nested split tree
+            if len(rules) == 0 or not isinstance(rules[0], list):
+                if len(rules) > 0:
+                    leaves.append(rules)
+            else:
+                for rules_child in rules:
+                    breadth_first_search(rules_child, leaves=leaves)
+            return leaves
+
+        leaves = []
+        res = split_with_best_feature(rules, self.max_depth_duplication)
+        breadth_first_search(res, leaves=leaves)
+        return leaves
+
+    def f1_score(self, x):
+        # harmonic mean of precision (x[1][0]) and recall (x[1][1])
+        return (2 * x[1][0] * x[1][1] / (x[1][0] + x[1][1])
+                if (x[1][0] + x[1][1]) > 0 else 0)
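
Taken together, the three new methods cluster rules by recursively splitting on the most frequent feature term, then keep the best rule (by F1 score) of each cluster. Below is a self-contained sketch of the same idea, runnable outside the class; the function names and the simplified (precision, recall) scores are illustrative, not part of the commit:

    from collections import Counter

    def cluster_rules(rules, depth, used=()):
        # rules: list of (rule_string, (precision, recall)) pairs
        if not rules:
            return []
        terms = Counter(term.split(' ')[0]
                        for rule, _ in rules
                        for term in rule.split(' and '))
        for feature in used:    # never split twice on the same feature
            del terms[feature]  # Counter tolerates deleting missing keys
        if depth == 0 or not terms:
            return [rules]
        feature = terms.most_common(1)[0][0]
        below, above, without = [], [], []
        for rule in rules:
            if feature + ' <=' in rule[0]:
                below.append(rule)
            elif feature + ' >' in rule[0]:
                above.append(rule)
            else:
                without.append(rule)
        clusters = []
        for subset in (below, above, without):
            clusters.extend(cluster_rules(subset, depth - 1,
                                          used + (feature,)))
        return clusters

    def f1(scores):
        precision, recall = scores
        return (2 * precision * recall / (precision + recall)
                if precision + recall > 0 else 0)

    def deduplicate(rules, depth=2):
        # keep the best rule of each cluster
        return [max(cluster, key=lambda r: f1(r[1]))
                for cluster in cluster_rules(rules, depth)]

    rules = [("a <= 2 and b > 45", (1.0, 0.9)),
             ("a <= 2 and b > 40", (0.8, 0.5)),
             ("c > 3", (0.7, 0.7))]
    print(deduplicate(rules))
    # keeps ("a <= 2 and b > 45", ...) and ("c > 3", ...)

Unlike the committed version, this sketch returns a flat list of clusters directly instead of rebuilding it from a nested list with a breadth-first search; the selection step is the same.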

skrules/tests/test_skope_rules.py

Lines changed: 45 additions & 0 deletions
@@ -13,6 +13,9 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_in
+from sklearn.utils.testing import assert_not_in
+from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import ignore_warnings
@@ -170,3 +173,45 @@ def test_performances():
     assert_equal(decision.shape, (n_samples,))
     dec_pred = (decision.ravel() < 0).astype(np.int)
     assert_array_equal(dec_pred, y_pred)
+
+
+def test_similarity_tree():
+    # test that rules are clustered correctly
+    rules = [("a <= 2 and b > 45 and c <= 3 and a > 4", (1, 1, 0)),
+             ("a <= 2 and b > 45 and c <= 3 and a > 4", (1, 1, 0)),
+             ("a > 2 and b > 45", (0.5, 0.3, 0)),
+             ("a > 2 and b > 40", (0.5, 0.2, 0)),
+             ("a <= 2 and b <= 45", (1, 1, 0)),
+             ("a > 2 and c <= 3", (1, 1, 0)),
+             ("b > 45", (1, 1, 0))]
+
+    sk = SkopeRules(max_depth_duplication=2)
+    rulesets = sk._find_similar_rulesets(rules)
+
+    # assert that some pairs of rules end up in the same bag
+    idx_bags_rules = []
+    for idx_rule, r in enumerate(rules):
+        idx_bags_for_rule = []
+        for idx_bag, bag in enumerate(rulesets):
+            if r in bag:
+                idx_bags_for_rule.append(idx_bag)
+        idx_bags_rules.append(idx_bags_for_rule)
+
+    assert_equal(idx_bags_rules[0], idx_bags_rules[1])
+    assert_not_equal(idx_bags_rules[0], idx_bags_rules[2])
+
+    # assert that only the best rule of each bag is kept
+    final_rules = sk.deduplicate(rules)
+    assert_in(rules[0], final_rules)
+    assert_in(rules[2], final_rules)
+    assert_not_in(rules[3], final_rules)
+
+
+def test_f1_score():
+    clf = SkopeRules()
+    rule0 = ('a > 0', (0, 0, 0))
+    rule1 = ('a > 0', (0.5, 0.5, 0))
+    rule2 = ('a > 0', (0.5, 0, 0))
+
+    assert_equal(clf.f1_score(rule0), 0)
+    assert_equal(clf.f1_score(rule1), 0.5)
+    assert_equal(clf.f1_score(rule2), 0)
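
The expected values follow directly from the formula in f1_score, which reads x[1][0] as precision p and x[1][1] as recall r: F1 = 2 * p * r / (p + r), defined as 0 when p + r = 0. So rule1 yields 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5, rule0 falls into the p + r = 0 branch, and rule2 has a zero numerator; both give 0.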
