1414
1515INTEGER_TYPES = (numbers .Integral , np .integer )
1616
17+
1718class SkopeRules (BaseEstimator ):
1819 """ An easy-interpretable classifier optimizing simple logical rules.
1920
@@ -158,7 +159,8 @@ def __init__(self,
158159 self .bootstrap = bootstrap
159160 self .bootstrap_features = bootstrap_features
160161 self .max_depth = max_depth
161- self .max_depths = max_depth if isinstance (max_depth , Iterable ) else [max_depth ]
162+ self .max_depths = max_depth \
163+ if isinstance (max_depth , Iterable ) else [max_depth ]
162164 self .max_depth_duplication = max_depth_duplication
163165 self .max_features = max_features
164166 self .min_samples_split = min_samples_split
@@ -361,7 +363,6 @@ def fit(self, X, y, sample_weight=None):
361363 # count representation of feature
362364 if self .max_depth_duplication is not None :
363365 self .rules_ = self .deduplicate (self .rules_ )
364- # TODO : Factorize disjoints performing rules (ex : c0 > 0 and c1 > 1 , c0 > 0 and c1 <= 1)
365366 return self
366367
367368 def predict (self , X ):
@@ -592,10 +593,12 @@ def _eval_rule_perf(self, rule, X, y):
592593 return y_detected .mean (), float (true_pos ) / pos
593594
594595 def deduplicate (self , rules ):
595- return [max (rules_set , key = self .f1_score ) for rules_set in self ._find_similar_rulesets (rules )]
596+ return [max (rules_set , key = self .f1_score )
597+ for rules_set in self ._find_similar_rulesets (rules )]
596598
597599 def _find_similar_rulesets (self , rules ):
598- """Create clusters of rules using a decision tree based on the terms of the rules
600+ """Create clusters of rules using a decision tree based
601+ on the terms of the rules
599602
600603 Parameters
601604 ----------
@@ -635,11 +638,11 @@ def split_with_best_feature(rules, depth, exceptions=[]):
635638 rules_splitted [1 ].append (rule )
636639 else :
637640 rules_splitted [2 ].append (rule )
638-
641+ new_exceptions = exceptions + [ most_represented_term ]
639642 # Choose best term
640643 return [split_with_best_feature (ruleset ,
641644 depth - 1 ,
642- exceptions = exceptions + [ most_represented_term ] )
645+ exceptions = new_exceptions )
643646 for ruleset in rules_splitted ]
644647
645648 def breadth_first_search (rules , leaves = None ):
@@ -656,4 +659,5 @@ def breadth_first_search(rules, leaves=None):
656659 return leaves
657660
658661 def f1_score (self , x ):
659- return 2 * x [1 ][0 ] * x [1 ][1 ] / (x [1 ][0 ] + x [1 ][1 ]) if (x [1 ][0 ] + x [1 ][1 ]) > 0 else 0
662+ return 2 * x [1 ][0 ] * x [1 ][1 ] / \
663+ (x [1 ][0 ] + x [1 ][1 ]) if (x [1 ][0 ] + x [1 ][1 ]) > 0 else 0
0 commit comments