Skip to content

Commit 2d1d29c

Browse files
author
Florian Gardin
committed
add choice of several max_depth
1 parent dd38580 commit 2d1d29c

File tree

1 file changed

+55
-46
lines changed

1 file changed

+55
-46
lines changed

skrules/skope_rules.py

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from collections import Counter
2+
from collections import Counter, Iterable
33
import pandas
44
import numbers
55
from warnings import warn
@@ -54,10 +54,13 @@ class SkopeRules(BaseEstimator):
5454
bootstrap_features : boolean, optional (default=False)
5555
Whether features are drawn with replacement.
5656
57-
max_depth : integer or None, optional (default=3)
57+
max_depth : integer or iterable of integers or None, optional (default=3)
5858
The maximum depth of the decision trees. If None, then nodes are
5959
expanded until all leaves are pure or until all leaves contain less
6060
than min_samples_split samples.
61+
If an iterable is passed, you will train n_estimators
62+
for each tree depth. It allows you to create and compare
63+
rules of different lengths.
6164
6265
max_depth_duplication : integer or None, optional (default=3)
6366
The maximum depth of the decision tree for rule deduplication,
@@ -155,6 +158,7 @@ def __init__(self,
155158
self.bootstrap = bootstrap
156159
self.bootstrap_features = bootstrap_features
157160
self.max_depth = max_depth
161+
self.max_depths = max_depth if isinstance(max_depth, Iterable) else [max_depth]
158162
self.max_depth_duplication = max_depth_duplication
159163
self.max_features = max_features
160164
self.min_samples_split = min_samples_split
@@ -242,40 +246,44 @@ def fit(self, X, y, sample_weight=None):
242246
else ['c' + x for x in
243247
np.arange(X.shape[1]).astype(str)])
244248
self.feature_names_ = feature_names_
245-
246-
bagging_clf = BaggingClassifier(
247-
base_estimator=DecisionTreeClassifier(
248-
max_depth=self.max_depth,
249-
max_features=self.max_features,
250-
min_samples_split=self.min_samples_split),
251-
n_estimators=self.n_estimators,
252-
max_samples=self.max_samples_,
253-
max_features=self.max_samples_features,
254-
bootstrap=self.bootstrap,
255-
bootstrap_features=self.bootstrap_features,
256-
# oob_score=... XXX may be added if selection on tree perf needed.
257-
# warm_start=... XXX may be added to increase computation perf.
258-
n_jobs=self.n_jobs,
259-
random_state=self.random_state,
260-
verbose=self.verbose)
261-
262-
bagging_reg = BaggingRegressor(
263-
base_estimator=DecisionTreeRegressor(
264-
max_depth=self.max_depth,
265-
max_features=self.max_features,
266-
min_samples_split=self.min_samples_split),
267-
n_estimators=self.n_estimators,
268-
max_samples=self.max_samples_,
269-
max_features=self.max_samples_features,
270-
bootstrap=self.bootstrap,
271-
bootstrap_features=self.bootstrap_features,
272-
# oob_score=... XXX may be added if selection on tree perf needed.
273-
# warm_start=... XXX may be added to increase computation perf.
274-
n_jobs=self.n_jobs,
275-
random_state=self.random_state,
276-
verbose=self.verbose)
277-
278-
bagging_clf.fit(X, y)
249+
clfs = []
250+
regs = []
251+
252+
for max_depth in self.max_depths:
253+
bagging_clf = BaggingClassifier(
254+
base_estimator=DecisionTreeClassifier(
255+
max_depth=max_depth,
256+
max_features=self.max_features,
257+
min_samples_split=self.min_samples_split),
258+
n_estimators=self.n_estimators,
259+
max_samples=self.max_samples_,
260+
max_features=self.max_samples_features,
261+
bootstrap=self.bootstrap,
262+
bootstrap_features=self.bootstrap_features,
263+
# oob_score=... XXX may be added if selection on tree perf needed.
264+
# warm_start=... XXX may be added to increase computation perf.
265+
n_jobs=self.n_jobs,
266+
random_state=self.random_state,
267+
verbose=self.verbose)
268+
269+
bagging_reg = BaggingRegressor(
270+
base_estimator=DecisionTreeRegressor(
271+
max_depth=max_depth,
272+
max_features=self.max_features,
273+
min_samples_split=self.min_samples_split),
274+
n_estimators=self.n_estimators,
275+
max_samples=self.max_samples_,
276+
max_features=self.max_samples_features,
277+
bootstrap=self.bootstrap,
278+
bootstrap_features=self.bootstrap_features,
279+
# oob_score=... XXX may be added if selection on tree perf needed.
280+
# warm_start=... XXX may be added to increase computation perf.
281+
n_jobs=self.n_jobs,
282+
random_state=self.random_state,
283+
verbose=self.verbose)
284+
285+
clfs.append(bagging_clf)
286+
regs.append(bagging_reg)
279287

280288
# define regression target:
281289
if sample_weight is not None:
@@ -290,16 +298,17 @@ def fit(self, X, y, sample_weight=None):
290298
else:
291299
y_reg = y # same as an other classification bagging
292300

293-
bagging_reg.fit(X, y_reg)
294-
295-
self.estimators_ += bagging_clf.estimators_
296-
self.estimators_ += bagging_reg.estimators_
297-
298-
self.estimators_samples_ += bagging_clf.estimators_samples_
299-
self.estimators_samples_ += bagging_reg.estimators_samples_
300-
301-
self.estimators_features_ += bagging_clf.estimators_features_
302-
self.estimators_features_ += bagging_reg.estimators_features_
301+
for clf in clfs:
302+
clf.fit(X, y)
303+
self.estimators_ += clf.estimators_
304+
self.estimators_samples_ += clf.estimators_samples_
305+
self.estimators_features_ += clf.estimators_features_
306+
307+
for reg in regs:
308+
reg.fit(X, y_reg)
309+
self.estimators_ += reg.estimators_
310+
self.estimators_samples_ += reg.estimators_samples_
311+
self.estimators_features_ += reg.estimators_features_
303312

304313
rules_ = []
305314
for estimator, samples, features in zip(self.estimators_,

0 commit comments

Comments
 (0)