Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions boruta/boruta_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import numpy as np
import scipy as sp
from sklearn.utils import check_random_state, check_X_y
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted
import warnings


class BorutaPy(BaseEstimator, TransformerMixin):
class BorutaPy(BaseEstimator, SelectorMixin):
"""
Improved Python implementation of the Boruta R package.

Expand Down Expand Up @@ -452,6 +454,10 @@ def _transform(self, X, weak=False, return_df=False):
X = X[:, indices]
return X

def _get_support_mask(self):
check_is_fitted(self, 'support_')
return self.support_

def _get_tree_num(self, n_feat):
depth = None
try:
Expand Down