 import numpy as np
-from collections import Counter
+from collections import Counter
+from collections.abc import Iterable
 import pandas
 import numbers
 from warnings import warn
@@ -54,10 +54,13 @@ class SkopeRules(BaseEstimator):
     bootstrap_features : boolean, optional (default=False)
         Whether features are drawn with replacement.
 
-    max_depth : integer or None, optional (default=3)
+    max_depth : integer, iterable of integers, or None, optional (default=3)
         The maximum depth of the decision trees. If None, then nodes are
         expanded until all leaves are pure or until all leaves contain less
         than min_samples_split samples.
+        If an iterable is passed, n_estimators trees are trained for each
+        depth in it, which makes it possible to build and compare rules of
+        different lengths.
 
     max_depth_duplication : integer or None, optional (default=3)
         The maximum depth of the decision tree for rule deduplication,
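A minimal usage sketch of the iterable max_depth behaviour documented above (assuming the skrules import path of the skope-rules package; the depth values are illustrative):

    from skrules import SkopeRules

    # An iterable of depths trains n_estimators trees per depth, so rules
    # of length up to 2 and up to 3 are generated and compared side by side.
    clf = SkopeRules(max_depth=[2, 3], n_estimators=10)

    # A plain integer keeps the previous single-depth behaviour.
    clf_single = SkopeRules(max_depth=3, n_estimators=10)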
@@ -155,6 +158,7 @@ def __init__(self,
         self.bootstrap = bootstrap
         self.bootstrap_features = bootstrap_features
         self.max_depth = max_depth
+        self.max_depths = (max_depth if isinstance(max_depth, Iterable)
+                           else [max_depth])
         self.max_depth_duplication = max_depth_duplication
         self.max_features = max_features
         self.min_samples_split = min_samples_split
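The added line normalises max_depth with isinstance(max_depth, Iterable); a standalone sketch of that pattern under the same assumption (as_depth_list is an illustrative name, not part of the patch):

    from collections.abc import Iterable

    def as_depth_list(max_depth):
        # None and plain integers are wrapped in a one-element list;
        # lists, tuples, ranges, etc. pass through unchanged.
        if isinstance(max_depth, Iterable):
            return list(max_depth)
        return [max_depth]

    assert as_depth_list(3) == [3]
    assert as_depth_list(None) == [None]
    assert as_depth_list(range(2, 5)) == [2, 3, 4]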
@@ -242,40 +246,44 @@ def fit(self, X, y, sample_weight=None):
                           else ['c' + x for x in
                                 np.arange(X.shape[1]).astype(str)])
         self.feature_names_ = feature_names_
-
-        bagging_clf = BaggingClassifier(
-            base_estimator=DecisionTreeClassifier(
-                max_depth=self.max_depth,
-                max_features=self.max_features,
-                min_samples_split=self.min_samples_split),
-            n_estimators=self.n_estimators,
-            max_samples=self.max_samples_,
-            max_features=self.max_samples_features,
-            bootstrap=self.bootstrap,
-            bootstrap_features=self.bootstrap_features,
-            # oob_score=... XXX may be added if selection on tree perf needed.
-            # warm_start=... XXX may be added to increase computation perf.
-            n_jobs=self.n_jobs,
-            random_state=self.random_state,
-            verbose=self.verbose)
-
-        bagging_reg = BaggingRegressor(
-            base_estimator=DecisionTreeRegressor(
-                max_depth=self.max_depth,
-                max_features=self.max_features,
-                min_samples_split=self.min_samples_split),
-            n_estimators=self.n_estimators,
-            max_samples=self.max_samples_,
-            max_features=self.max_samples_features,
-            bootstrap=self.bootstrap,
-            bootstrap_features=self.bootstrap_features,
-            # oob_score=... XXX may be added if selection on tree perf needed.
-            # warm_start=... XXX may be added to increase computation perf.
-            n_jobs=self.n_jobs,
-            random_state=self.random_state,
-            verbose=self.verbose)
-
-        bagging_clf.fit(X, y)
+        clfs = []
+        regs = []
+
+        for max_depth in self.max_depths:
+            bagging_clf = BaggingClassifier(
+                base_estimator=DecisionTreeClassifier(
+                    max_depth=max_depth,
+                    max_features=self.max_features,
+                    min_samples_split=self.min_samples_split),
+                n_estimators=self.n_estimators,
+                max_samples=self.max_samples_,
+                max_features=self.max_samples_features,
+                bootstrap=self.bootstrap,
+                bootstrap_features=self.bootstrap_features,
+                # oob_score=... XXX may be added if selection on tree perf needed.
+                # warm_start=... XXX may be added to increase computation perf.
+                n_jobs=self.n_jobs,
+                random_state=self.random_state,
+                verbose=self.verbose)
+
+            bagging_reg = BaggingRegressor(
+                base_estimator=DecisionTreeRegressor(
+                    max_depth=max_depth,
+                    max_features=self.max_features,
+                    min_samples_split=self.min_samples_split),
+                n_estimators=self.n_estimators,
+                max_samples=self.max_samples_,
+                max_features=self.max_samples_features,
+                bootstrap=self.bootstrap,
+                bootstrap_features=self.bootstrap_features,
+                # oob_score=... XXX may be added if selection on tree perf needed.
+                # warm_start=... XXX may be added to increase computation perf.
+                n_jobs=self.n_jobs,
+                random_state=self.random_state,
+                verbose=self.verbose)
+
+            clfs.append(bagging_clf)
+            regs.append(bagging_reg)
 
         # define regression target:
         if sample_weight is not None:
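The loop above builds one bagged classifier/regressor pair per requested depth. A condensed sketch of that construction on its own (parameter values are illustrative; base_estimator matches the scikit-learn API this patch targets, later renamed to estimator in scikit-learn 1.2):

    from sklearn.ensemble import BaggingClassifier
    from sklearn.tree import DecisionTreeClassifier

    max_depths = [2, 3, 4]  # illustrative
    clfs = [BaggingClassifier(
                base_estimator=DecisionTreeClassifier(max_depth=d),
                n_estimators=10)
            for d in max_depths]
    # Each ensemble's trees are capped at depth d, so shallow and deep
    # rule candidates later coexist side by side in self.estimators_.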
@@ -290,16 +298,17 @@ def fit(self, X, y, sample_weight=None):
         else:
             y_reg = y  # same as another classification bagging
 
-        bagging_reg.fit(X, y_reg)
-
-        self.estimators_ += bagging_clf.estimators_
-        self.estimators_ += bagging_reg.estimators_
-
-        self.estimators_samples_ += bagging_clf.estimators_samples_
-        self.estimators_samples_ += bagging_reg.estimators_samples_
-
-        self.estimators_features_ += bagging_clf.estimators_features_
-        self.estimators_features_ += bagging_reg.estimators_features_
+        for clf in clfs:
+            clf.fit(X, y)
+            self.estimators_ += clf.estimators_
+            self.estimators_samples_ += clf.estimators_samples_
+            self.estimators_features_ += clf.estimators_features_
+
+        for reg in regs:
+            reg.fit(X, y_reg)
+            self.estimators_ += reg.estimators_
+            self.estimators_samples_ += reg.estimators_samples_
+            self.estimators_features_ += reg.estimators_features_
 
         rules_ = []
         for estimator, samples, features in zip(self.estimators_,
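End to end, the patch lets a single fit produce rules of several lengths. A hedged sketch on synthetic data (in released skope-rules versions rules_ holds (rule, (precision, recall, nb)) tuples; the feature names, threshold, and depths here are illustrative):

    import numpy as np
    from skrules import SkopeRules

    rng = np.random.RandomState(42)
    X = rng.randn(500, 4)
    y = (X[:, 0] > 0.5).astype(int)  # synthetic target

    clf = SkopeRules(max_depth=[2, 3], n_estimators=10,
                     feature_names=['f0', 'f1', 'f2', 'f3'])
    clf.fit(X, y)
    for rule, score in clf.rules_:
        print(rule, score)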