 import numpy as np
+from collections import Counter
 import pandas
 import numbers
 from warnings import warn
+
 from sklearn.base import BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import check_classification_targets
-
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-
 from sklearn.ensemble import BaggingClassifier, BaggingRegressor
-
 from sklearn.externals import six
 from sklearn.tree import _tree
 
@@ -66,6 +65,9 @@ class SkopeRules(BaseEstimator):
         expanded until all leaves are pure or until all leaves contain less
         than min_samples_split samples.
 
+    max_depth_duplication : integer or None, optional (default=3)
+        Maximum depth of the rule-deduplication tree. None disables it.
+
     max_features : int, float, string or None, optional (default="auto")
         The number of features considered (by each decision tree) when looking
         for the best split:
@@ -144,6 +146,7 @@ def __init__(self,
                  bootstrap=False,
                  bootstrap_features=False,
                  max_depth=3,
+                 max_depth_duplication=3,
                  max_features=1.,
                  min_samples_split=2,
                  n_jobs=1,
@@ -159,6 +162,7 @@ def __init__(self,
         self.bootstrap = bootstrap
         self.bootstrap_features = bootstrap_features
         self.max_depth = max_depth
+        self.max_depth_duplication = max_depth_duplication
         self.max_features = max_features
         self.min_samples_split = min_samples_split
         self.n_jobs = n_jobs
@@ -357,34 +361,10 @@ def fit(self, X, y, sample_weight=None):
         self.rules_ = sorted(self.rules_.items(),
                              key=lambda x: (x[1][0], x[1][1]), reverse=True)
 
-        # removing rules which have very similar domains
-        X_ = pandas.DataFrame(X, columns=np.array(self.feature_names_))
-        omit_these_rules_list = []
-        perimeter_index_of_all_rules = []
-        for i in range(len(self.rules_)):
-            current = self.rules_[i]
-            perimeter_index_of_all_rules.append(
-                set(list(X_.query(current[0]).index))
-            )
-            index_current = perimeter_index_of_all_rules[i]
-
-            for j in range(i):
-                if j in omit_these_rules_list:
-                    continue
-                # if a rule has already been discarded,
-                # it should not be processed again
-
-                index_rival = perimeter_index_of_all_rules[j]
-                size_union = len(index_rival.union(index_current))
-                size_intersection = len(
-                    index_rival.intersection(index_current))
-
-                if float(size_intersection) / size_union > self.similarity_thres:
-                    omit_these_rules_list.append(j)
-
-        self.rules_ = [self.rules_[i] for i in range(
-            len(self.rules_)) if i not in omit_these_rules_list]
-
+        # deduplicate rules by clustering them on their most frequent terms
+        if self.max_depth_duplication is not None:
+            self.rules_ = self.deduplicate(self.rules_)
+        # TODO: factorize rules with complementary terms (e.g. "c0 > 0 and c1 > 1", "c0 > 0 and c1 <= 1")
         return self
 
     def predict(self, X):
@@ -613,3 +593,78 @@ def _eval_rule_perf(self, rule, X, y):
             return (0, 0)
         pos = y[y > 0].sum()
         return y_detected.mean(), float(true_pos) / pos
+
+    def deduplicate(self, rules):
+        # keep only the best rule (by F1 score) of each cluster of similar rules
+        return [max(rules_set, key=self.f1_score)
+                for rules_set in self._find_similar_rulesets(rules)]
+
+    def _find_similar_rulesets(self, rules):
+        """Create clusters of rules using a decision tree based on the
+        terms of the rules.
+
+        Parameters
+        ----------
+        rules : list
+            List of rules.
+
+        Returns
+        -------
+        rules : list of lists of rules
+        """
+        def split_with_best_feature(rules, depth, exceptions=None):
+            """Recursively split the rules on their most represented
+            feature, up to the given depth."""
+            if exceptions is None:
+                exceptions = []
+            if depth == 0:
+                return rules
+
+            rulelist = [rule.split(' and ') for rule, score in rules]
+            terms = [t.split(' ')[0] for rule_terms in rulelist
+                     for t in rule_terms]
+            counter = Counter(terms)
+            # drop the features already used for a split higher up
+            for exception in exceptions:
+                del counter[exception]
+
+            if len(counter) == 0:
+                return rules
+
+            most_represented_term = counter.most_common()[0][0]
+            # split the rules into three buckets: those testing
+            # "feature <=", those testing "feature >", and the others
+            rules_split = [[], [], []]
+            for rule in rules:
+                if (most_represented_term + ' <=') in rule[0]:
+                    rules_split[0].append(rule)
+                elif (most_represented_term + ' >') in rule[0]:
+                    rules_split[1].append(rule)
+                else:
+                    rules_split[2].append(rule)
+
+            # recurse on each bucket, excluding the feature just used
+            return [split_with_best_feature(
+                        ruleset, depth - 1,
+                        exceptions=exceptions + [most_represented_term])
+                    for ruleset in rules_split]
+
+        def breadth_first_search(rules, leaves=None):
+            # walk the nested lists and collect the non-empty leaf rule sets
+            if len(rules) == 0 or not isinstance(rules[0], list):
+                if len(rules) > 0:
+                    leaves.append(rules)
+            else:
+                for rules_child in rules:
+                    breadth_first_search(rules_child, leaves=leaves)
+            return leaves
+
+        leaves = []
+        res = split_with_best_feature(rules, self.max_depth_duplication)
+        breadth_first_search(res, leaves=leaves)
+        return leaves
+
+    def f1_score(self, x):
+        """Harmonic mean of the (precision, recall) pair stored in x[1]."""
+        return (2 * x[1][0] * x[1][1] / (x[1][0] + x[1][1])
+                if (x[1][0] + x[1][1]) > 0 else 0)
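
For review: a small standalone sketch of what the new deduplication does, using the (rule, (precision, recall)) shape that self.rules_ holds after fit. The toy rules and scores below are invented for illustration; the snippet mirrors the term-counting and F1 selection from the diff rather than importing this branch.

from collections import Counter

# toy rules in the (rule, (precision, recall)) shape of self.rules_
rules = [('c0 <= 0.5 and c1 > 1.0', (0.9, 0.6)),
         ('c0 <= 0.5 and c1 > 1.2', (0.8, 0.7)),
         ('c0 > 0.5', (0.5, 0.9))]

# the most represented feature, as computed in split_with_best_feature
terms = [t.split(' ')[0] for rule, _ in rules for t in rule.split(' and ')]
assert Counter(terms).most_common()[0][0] == 'c0'

# the two 'c0 <=' rules land in one cluster, 'c0 > 0.5' in another;
# deduplicate() then keeps the best rule of each cluster by F1 score
def f1(x):
    p, r = x[1]
    return 2 * p * r / (p + r) if (p + r) > 0 else 0

assert max(rules[:2], key=f1) == ('c0 <= 0.5 and c1 > 1.2', (0.8, 0.7))

With max_depth_duplication levels of splitting, each leaf of this feature-based tree is a cluster of rules with similar terms, so near-duplicate rules collapse to a single representative instead of being compared pairwise by support overlap as before.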