2727from sklearn .utils .multiclass import type_of_target
2828from sklearn .utils .parallel import Parallel , delayed
2929from sklearn .utils .validation import _check_sample_weight
30+ from sklearn .utils ._param_validation import Hidden , Interval , StrOptions
3031
31- from ..base import _ParamsValidationMixin
3232from ..pipeline import make_pipeline
3333from ..under_sampling import RandomUnderSampler
3434from ..utils import Substitution
3535from ..utils ._docstring import _n_jobs_docstring , _random_state_docstring
36- from ..utils ._param_validation import Hidden , Interval , StrOptions
36+ from ..utils ._sklearn_compat import _fit_context , validate_data
3737from ..utils ._validation import check_sampling_strategy
38- from ..utils .fixes import _fit_context
3938from ._common import _random_forest_classifier_parameter_constraints
4039
4140MAX_INT = np .iinfo (np .int32 ).max
42- sklearn_version = parse_version (sklearn .__version__ )
41+ sklearn_version = parse_version (parse_version ( sklearn .__version__ ). base_version )
4342
4443
4544def _local_parallel_build_trees (
@@ -77,7 +76,7 @@ def _local_parallel_build_trees(
7776 "bootstrap" : bootstrap ,
7877 }
7978
80- if parse_version ( sklearn_version . base_version ) >= parse_version ("1.4" ):
79+ if sklearn_version >= parse_version ("1.4" ):
8180 # TODO: remove when the minimum supported version of scikit-learn will be 1.4
8281 # support for missing values
8382 params_parallel_build_trees ["missing_values_in_feature_mask" ] = (
@@ -93,7 +92,7 @@ def _local_parallel_build_trees(
9392 n_jobs = _n_jobs_docstring ,
9493 random_state = _random_state_docstring ,
9594)
96- class BalancedRandomForestClassifier (_ParamsValidationMixin , RandomForestClassifier ):
95+ class BalancedRandomForestClassifier (RandomForestClassifier ):
9796 """A balanced random forest classifier.
9897
9998 A balanced random forest differs from a classical random forest by the
@@ -474,7 +473,7 @@ def __init__(
474473 "max_samples" : max_samples ,
475474 }
476475 # TODO: remove when the minimum supported version of scikit-learn will be 1.4
477- if parse_version ( sklearn_version . base_version ) >= parse_version ("1.4" ):
476+ if sklearn_version >= parse_version ("1.4" ):
478477 # use scikit-learn support for monotonic constraints
479478 params_random_forest ["monotonic_cst" ] = monotonic_cst
480479 else :
@@ -596,22 +595,23 @@ def fit(self, X, y, sample_weight=None):
596595
597596 # TODO: remove when the minimum supported version of scipy will be 1.4
598597 # Support for missing values
599- if parse_version ( sklearn_version . base_version ) >= parse_version ("1.4" ):
600- force_all_finite = False
598+ if sklearn_version >= parse_version ("1.4" ):
599+ ensure_all_finite = False
601600 else :
602- force_all_finite = True
601+ ensure_all_finite = True
603602
604- X , y = self ._validate_data (
605- X ,
606- y ,
603+ X , y = validate_data (
604+ self ,
605+ X = X ,
606+ y = y ,
607607 multi_output = True ,
608608 accept_sparse = "csc" ,
609609 dtype = DTYPE ,
610- force_all_finite = force_all_finite ,
610+ ensure_all_finite = ensure_all_finite ,
611611 )
612612
613613 # TODO: remove when the minimum supported version of scikit-learn will be 1.4
614- if parse_version ( sklearn_version . base_version ) >= parse_version ("1.4" ):
614+ if sklearn_version >= parse_version ("1.4" ):
615615 # _compute_missing_values_in_feature_mask checks if X has missing values and
616616 # will raise an error if the underlying tree base estimator can't handle
617617 # missing values. Only the criterion is required to determine if the tree
@@ -882,3 +882,10 @@ def _compute_oob_predictions(self, X, y):
882882
883883 def _more_tags (self ):
884884 return {"multioutput" : False , "multilabel" : False }
885+
886+ def __sklearn_tags__ (self ):
887+ tags = super ().__sklearn_tags__ ()
888+ tags .target_tags .multi_output = False
889+ tags .classifier_tags .multi_label = False
890+ tags .input_tags .allow_nan = sklearn_version >= parse_version ("1.4" )
891+ return tags
0 commit comments