Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion skrules/skope_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class SkopeRules(BaseEstimator):
`ceil(min_samples_split * n_samples)` are the minimum
number of samples for each split.

class_weight: dict, list of dict or "balanced", default=None
The weights to be used for the DecisionTreeClassifier. Weights associated
with classes in the form {class_label: weight}. If None,
all classes are supposed to have weight one.

n_jobs : integer, optional (default=1)
The number of jobs to run in parallel for both `fit` and `predict`.
If -1, then the number of jobs is set to the number of cores.
Expand Down Expand Up @@ -150,6 +155,7 @@ def __init__(self,
max_depth_duplication=None,
max_features=1.,
min_samples_split=2,
class_weight=None,
n_jobs=1,
random_state=None,
verbose=0):
Expand All @@ -164,6 +170,7 @@ def __init__(self,
self.max_depth = max_depth
self.max_depth_duplication = max_depth_duplication
self.max_features = max_features
self.class_weight = class_weight
self.min_samples_split = min_samples_split
self.n_jobs = n_jobs
self.random_state = random_state
Expand Down Expand Up @@ -270,7 +277,8 @@ def fit(self, X, y, sample_weight=None):
base_estimator=DecisionTreeClassifier(
max_depth=max_depth,
max_features=self.max_features,
min_samples_split=self.min_samples_split),
min_samples_split=self.min_samples_split,
class_weight=self.class_weight),
n_estimators=self.n_estimators,
max_samples=self.max_samples_,
max_features=self.max_samples_features,
Expand Down
3 changes: 3 additions & 0 deletions skrules/tests/test_skope_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def test_skope_rules():
recall_min=0.,
precision_min=0.).fit(X_train, y_train).predict(X_test)

# with additional class weights
SkopeRules(n_estimators=50, class_weight='balanced').fit(X_train, y_train).predict(X_test)


def test_skope_rules_error():
"""Test that it gives proper exception on deficient input."""
Expand Down