Skip to content

Commit 38a2f23

Browse files
committed
adapt code to new check_estimator + cosmit
1 parent 6ea5a25 commit 38a2f23

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

skrules/rule.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import re
22

3+
34
def replace_feature_name(rule, replace_dict):
45
def replace(match):
56
return replace_dict[match.group(0)]
67

78
rule = re.sub('|'.join(r'\b%s\b' % re.escape(s) for s in replace_dict),
8-
replace, rule)
9+
replace, rule)
910
return rule
1011

12+
1113
class Rule:
12-
""" An object modelizing a logical rule and add factorization methods.
14+
""" An object modelling a logical rule and add factorization methods.
1315
It is used to simplify rules and deduplicate them.
1416
1517
Parameters
@@ -66,4 +68,3 @@ def __repr__(self):
6668
[feature, symbol, str(self.agg_dict[(feature, symbol)])])
6769
for feature, symbol in sorted(self.agg_dict.keys())
6870
])
69-

skrules/skope_rules.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
INTEGER_TYPES = (numbers.Integral, np.integer)
1818
BASE_FEATURE_NAME = "__C__"
1919

20+
2021
class SkopeRules(BaseEstimator):
21-
""" An easy-interpretable classifier optimizing simple logical rules.
22+
"""An easy-interpretable classifier optimizing simple logical rules.
2223
2324
Parameters
2425
----------
@@ -161,8 +162,6 @@ def __init__(self,
161162
self.bootstrap = bootstrap
162163
self.bootstrap_features = bootstrap_features
163164
self.max_depth = max_depth
164-
self.max_depths = max_depth \
165-
if isinstance(max_depth, Iterable) else [max_depth]
166165
self.max_depth_duplication = max_depth_duplication
167166
self.max_features = max_features
168167
self.min_samples_split = min_samples_split
@@ -251,7 +250,7 @@ def fit(self, X, y, sample_weight=None):
251250

252251
# default columns names :
253252
feature_names_ = [BASE_FEATURE_NAME + x for x in
254-
np.arange(X.shape[1]).astype(str)]
253+
np.arange(X.shape[1]).astype(str)]
255254
if self.feature_names is not None:
256255
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
257256
for i, feat in enumerate(self.feature_names)}
@@ -263,7 +262,10 @@ def fit(self, X, y, sample_weight=None):
263262
clfs = []
264263
regs = []
265264

266-
for max_depth in self.max_depths:
265+
self._max_depths = self.max_depth \
266+
if isinstance(self.max_depth, Iterable) else [self.max_depth]
267+
268+
for max_depth in self._max_depths:
267269
bagging_clf = BaggingClassifier(
268270
base_estimator=DecisionTreeClassifier(
269271
max_depth=max_depth,
@@ -362,10 +364,6 @@ def fit(self, X, y, sample_weight=None):
362364
for rule in
363365
[Rule(r, args=args) for r, args in rules_]]
364366

365-
366-
367-
368-
369367
# keep only rules verifying precision_min and recall_min:
370368
for rule, score in rules_:
371369
if score[0] >= self.precision_min and score[1] >= self.recall_min:
@@ -393,7 +391,7 @@ def fit(self, X, y, sample_weight=None):
393391

394392
# Replace generic feature names by real feature names
395393
self.rules_ = [(replace_feature_name(rule, self.feature_dict_), perf)
396-
for rule, perf in self.rules_]
394+
for rule, perf in self.rules_]
397395

398396
return self
399397

0 commit comments

Comments
 (0)