diff --git a/skrules/skope_rules.py b/skrules/skope_rules.py index 899f166..edb109f 100644 --- a/skrules/skope_rules.py +++ b/skrules/skope_rules.py @@ -95,6 +95,11 @@ class SkopeRules(BaseEstimator): `ceil(min_samples_split * n_samples)` are the minimum number of samples for each split. + class_weight: dict, list of dict or "balanced", default=None + The weights to be used for the DecisionTreeClassifier. Weights associated + with classes in the form {class_label: weight}. If None, + all classes are supposed to have weight one. + n_jobs : integer, optional (default=1) The number of jobs to run in parallel for both `fit` and `predict`. If -1, then the number of jobs is set to the number of cores. @@ -150,6 +155,7 @@ def __init__(self, max_depth_duplication=None, max_features=1., min_samples_split=2, + class_weight=None, n_jobs=1, random_state=None, verbose=0): @@ -164,6 +170,7 @@ def __init__(self, self.max_depth = max_depth self.max_depth_duplication = max_depth_duplication self.max_features = max_features + self.class_weight = class_weight self.min_samples_split = min_samples_split self.n_jobs = n_jobs self.random_state = random_state @@ -270,7 +277,8 @@ def fit(self, X, y, sample_weight=None): base_estimator=DecisionTreeClassifier( max_depth=max_depth, max_features=self.max_features, - min_samples_split=self.min_samples_split), + min_samples_split=self.min_samples_split, + class_weight=self.class_weight), n_estimators=self.n_estimators, max_samples=self.max_samples_, max_features=self.max_samples_features, diff --git a/skrules/tests/test_skope_rules.py b/skrules/tests/test_skope_rules.py index 238871e..56c5bdc 100644 --- a/skrules/tests/test_skope_rules.py +++ b/skrules/tests/test_skope_rules.py @@ -72,6 +72,9 @@ def test_skope_rules(): recall_min=0., precision_min=0.).fit(X_train, y_train).predict(X_test) + # with additional class weights + SkopeRules(n_estimators=50, class_weight='balanced').fit(X_train, y_train).predict(X_test) + def test_skope_rules_error(): """Test that it gives proper exception on deficient input."""