1717INTEGER_TYPES = (numbers .Integral , np .integer )
1818BASE_FEATURE_NAME = "__C__"
1919
20+
2021class SkopeRules (BaseEstimator ):
21- """ An easy-interpretable classifier optimizing simple logical rules.
22+ """An easy-interpretable classifier optimizing simple logical rules.
2223
2324 Parameters
2425 ----------
@@ -161,8 +162,6 @@ def __init__(self,
161162 self .bootstrap = bootstrap
162163 self .bootstrap_features = bootstrap_features
163164 self .max_depth = max_depth
164- self .max_depths = max_depth \
165- if isinstance (max_depth , Iterable ) else [max_depth ]
166165 self .max_depth_duplication = max_depth_duplication
167166 self .max_features = max_features
168167 self .min_samples_split = min_samples_split
@@ -251,7 +250,7 @@ def fit(self, X, y, sample_weight=None):
251250
252251 # default columns names :
253252 feature_names_ = [BASE_FEATURE_NAME + x for x in
254- np .arange (X .shape [1 ]).astype (str )]
253+ np .arange (X .shape [1 ]).astype (str )]
255254 if self .feature_names is not None :
256255 self .feature_dict_ = {BASE_FEATURE_NAME + str (i ): feat
257256 for i , feat in enumerate (self .feature_names )}
@@ -263,7 +262,10 @@ def fit(self, X, y, sample_weight=None):
263262 clfs = []
264263 regs = []
265264
266- for max_depth in self .max_depths :
265+ self ._max_depths = self .max_depth \
266+ if isinstance (self .max_depth , Iterable ) else [self .max_depth ]
267+
268+ for max_depth in self ._max_depths :
267269 bagging_clf = BaggingClassifier (
268270 base_estimator = DecisionTreeClassifier (
269271 max_depth = max_depth ,
@@ -362,10 +364,6 @@ def fit(self, X, y, sample_weight=None):
362364 for rule in
363365 [Rule (r , args = args ) for r , args in rules_ ]]
364366
365-
366-
367-
368-
369367 # keep only rules verifying precision_min and recall_min:
370368 for rule , score in rules_ :
371369 if score [0 ] >= self .precision_min and score [1 ] >= self .recall_min :
@@ -393,7 +391,7 @@ def fit(self, X, y, sample_weight=None):
393391
394392 # Replace generic feature names by real feature names
395393 self .rules_ = [(replace_feature_name (rule , self .feature_dict_ ), perf )
396- for rule , perf in self .rules_ ]
394+ for rule , perf in self .rules_ ]
397395
398396 return self
399397
0 commit comments