Skip to content

Commit b7d4869

Browse files
authored
Merge pull request #23 from ngoix/fix_check_estimator
[MRG] Fix check estimator
2 parents 0cb31c7 + 38a2f23 commit b7d4869

File tree

4 files changed

+14
-15
lines changed

4 files changed

+14
-15
lines changed

doc/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ estimator with unit tests, along with examples and benchmarks.
2525
auto_examples/index
2626
...
2727

28-
See the `README <https://github.com/skope-rules/skope-rules/blob/master/README.md>`_
28+
See the `README <https://github.com/scikit-learn-contrib/skope-rules/blob/master/README.md>`_
2929
for more information.
3030

3131

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
setup(name='skope-rules',
1818
version='1.0.0',
1919
description='Machine Learning with Interpretable Rules',
20-
url='https://github.com/skope-rules/skope-rules',
20+
url='https://github.com/scikit-learn-contrib/skope-rules',
2121
author='see AUTHORS.rst',
2222
license='BSD 3 clause',
2323
packages=find_packages(),

skrules/rule.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import re
22

3+
34
def replace_feature_name(rule, replace_dict):
45
def replace(match):
56
return replace_dict[match.group(0)]
67

78
rule = re.sub('|'.join(r'\b%s\b' % re.escape(s) for s in replace_dict),
8-
replace, rule)
9+
replace, rule)
910
return rule
1011

12+
1113
class Rule:
12-
""" An object modelizing a logical rule and add factorization methods.
14+
""" An object modelling a logical rule and add factorization methods.
1315
It is used to simplify rules and deduplicate them.
1416
1517
Parameters
@@ -66,4 +68,3 @@ def __repr__(self):
6668
[feature, symbol, str(self.agg_dict[(feature, symbol)])])
6769
for feature, symbol in sorted(self.agg_dict.keys())
6870
])
69-

skrules/skope_rules.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
INTEGER_TYPES = (numbers.Integral, np.integer)
1818
BASE_FEATURE_NAME = "__C__"
1919

20+
2021
class SkopeRules(BaseEstimator):
21-
""" An easy-interpretable classifier optimizing simple logical rules.
22+
"""An easy-interpretable classifier optimizing simple logical rules.
2223
2324
Parameters
2425
----------
@@ -161,8 +162,6 @@ def __init__(self,
161162
self.bootstrap = bootstrap
162163
self.bootstrap_features = bootstrap_features
163164
self.max_depth = max_depth
164-
self.max_depths = max_depth \
165-
if isinstance(max_depth, Iterable) else [max_depth]
166165
self.max_depth_duplication = max_depth_duplication
167166
self.max_features = max_features
168167
self.min_samples_split = min_samples_split
@@ -251,7 +250,7 @@ def fit(self, X, y, sample_weight=None):
251250

252251
# default columns names :
253252
feature_names_ = [BASE_FEATURE_NAME + x for x in
254-
np.arange(X.shape[1]).astype(str)]
253+
np.arange(X.shape[1]).astype(str)]
255254
if self.feature_names is not None:
256255
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
257256
for i, feat in enumerate(self.feature_names)}
@@ -263,7 +262,10 @@ def fit(self, X, y, sample_weight=None):
263262
clfs = []
264263
regs = []
265264

266-
for max_depth in self.max_depths:
265+
self._max_depths = self.max_depth \
266+
if isinstance(self.max_depth, Iterable) else [self.max_depth]
267+
268+
for max_depth in self._max_depths:
267269
bagging_clf = BaggingClassifier(
268270
base_estimator=DecisionTreeClassifier(
269271
max_depth=max_depth,
@@ -362,10 +364,6 @@ def fit(self, X, y, sample_weight=None):
362364
for rule in
363365
[Rule(r, args=args) for r, args in rules_]]
364366

365-
366-
367-
368-
369367
# keep only rules verifying precision_min and recall_min:
370368
for rule, score in rules_:
371369
if score[0] >= self.precision_min and score[1] >= self.recall_min:
@@ -393,7 +391,7 @@ def fit(self, X, y, sample_weight=None):
393391

394392
# Replace generic feature names by real feature names
395393
self.rules_ = [(replace_feature_name(rule, self.feature_dict_), perf)
396-
for rule, perf in self.rules_]
394+
for rule, perf in self.rules_]
397395

398396
return self
399397

0 commit comments

Comments
 (0)