@@ -60,6 +60,7 @@ def _local_parallel_build_trees(
60
60
class_weight = None ,
61
61
n_samples_bootstrap = None ,
62
62
forest = None ,
63
+ missing_values_in_feature_mask = None ,
63
64
):
64
65
# resample before to fit the tree
65
66
X_resampled , y_resampled = sampler .fit_resample (X , y )
@@ -68,33 +69,34 @@ def _local_parallel_build_trees(
68
69
if _get_n_samples_bootstrap is not None :
69
70
n_samples_bootstrap = min (n_samples_bootstrap , X_resampled .shape [0 ])
70
71
71
- if sklearn_version >= parse_version ("1.1" ):
72
- tree = _parallel_build_trees (
73
- tree ,
74
- bootstrap ,
75
- X_resampled ,
76
- y_resampled ,
77
- sample_weight ,
78
- tree_idx ,
79
- n_trees ,
80
- verbose = verbose ,
81
- class_weight = class_weight ,
82
- n_samples_bootstrap = n_samples_bootstrap ,
83
- )
72
+ params_parallel_build_trees = {
73
+ "tree" : tree ,
74
+ "X" : X_resampled ,
75
+ "y" : y_resampled ,
76
+ "sample_weight" : sample_weight ,
77
+ "tree_idx" : tree_idx ,
78
+ "n_trees" : n_trees ,
79
+ "verbose" : verbose ,
80
+ "class_weight" : class_weight ,
81
+ "n_samples_bootstrap" : n_samples_bootstrap ,
82
+ }
83
+
84
+ if parse_version (sklearn_version .base_version ) >= parse_version ("1.4" ):
85
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
86
+ # support for missing values
87
+ params_parallel_build_trees [
88
+ "missing_values_in_feature_mask"
89
+ ] = missing_values_in_feature_mask
90
+
91
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.1
92
+ # change of signature in scikit-learn 1.1
93
+ if parse_version (sklearn_version .base_version ) >= parse_version ("1.1" ):
94
+ params_parallel_build_trees ["bootstrap" ] = bootstrap
84
95
else :
85
- # TODO: remove when the minimum version of scikit-learn supported is 1.1
86
- tree = _parallel_build_trees (
87
- tree ,
88
- forest ,
89
- X_resampled ,
90
- y_resampled ,
91
- sample_weight ,
92
- tree_idx ,
93
- n_trees ,
94
- verbose = verbose ,
95
- class_weight = class_weight ,
96
- n_samples_bootstrap = n_samples_bootstrap ,
97
- )
96
+ params_parallel_build_trees ["forest" ] = forest
97
+
98
+ tree = _parallel_build_trees (** params_parallel_build_trees )
99
+
98
100
return sampler , tree
99
101
100
102
@@ -305,6 +307,25 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
305
307
.. versionadded:: 0.6
306
308
Added in `scikit-learn` in 0.22
307
309
310
+ monotonic_cst : array-like of int of shape (n_features), default=None
311
+ Indicates the monotonicity constraint to enforce on each feature.
312
+ - 1: monotonic increase
313
+ - 0: no constraint
314
+ - -1: monotonic decrease
315
+
316
+ If monotonic_cst is None, no constraints are applied.
317
+
318
+ Monotonicity constraints are not supported for:
319
+ - multiclass classifications (i.e. when `n_classes > 2`),
320
+ - multioutput classifications (i.e. when `n_outputs_ > 1`),
321
+ - classifications trained on data with missing values.
322
+
323
+ The constraints hold over the probability of the positive class.
324
+
325
+ .. versionadded:: 0.12
326
+ Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
327
+ `ValueError` is raised.
328
+
308
329
Attributes
309
330
----------
310
331
estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
@@ -415,7 +436,7 @@ class labels (multi-output problem).
415
436
"""
416
437
417
438
# make a deepcopy to not modify the original dictionary
418
- if sklearn_version >= parse_version ("1.3 " ):
439
+ if sklearn_version >= parse_version ("1.4 " ):
419
440
_parameter_constraints = deepcopy (RandomForestClassifier ._parameter_constraints )
420
441
else :
421
442
_parameter_constraints = deepcopy (
@@ -459,27 +480,42 @@ def __init__(
459
480
class_weight = None ,
460
481
ccp_alpha = 0.0 ,
461
482
max_samples = None ,
483
+ monotonic_cst = None ,
462
484
):
463
- super ().__init__ (
464
- criterion = criterion ,
465
- max_depth = max_depth ,
466
- n_estimators = n_estimators ,
467
- bootstrap = bootstrap ,
468
- oob_score = oob_score ,
469
- n_jobs = n_jobs ,
470
- random_state = random_state ,
471
- verbose = verbose ,
472
- warm_start = warm_start ,
473
- class_weight = class_weight ,
474
- min_samples_split = min_samples_split ,
475
- min_samples_leaf = min_samples_leaf ,
476
- min_weight_fraction_leaf = min_weight_fraction_leaf ,
477
- max_features = max_features ,
478
- max_leaf_nodes = max_leaf_nodes ,
479
- min_impurity_decrease = min_impurity_decrease ,
480
- ccp_alpha = ccp_alpha ,
481
- max_samples = max_samples ,
482
- )
485
+ params_random_forest = {
486
+ "criterion" : criterion ,
487
+ "max_depth" : max_depth ,
488
+ "n_estimators" : n_estimators ,
489
+ "bootstrap" : bootstrap ,
490
+ "oob_score" : oob_score ,
491
+ "n_jobs" : n_jobs ,
492
+ "random_state" : random_state ,
493
+ "verbose" : verbose ,
494
+ "warm_start" : warm_start ,
495
+ "class_weight" : class_weight ,
496
+ "min_samples_split" : min_samples_split ,
497
+ "min_samples_leaf" : min_samples_leaf ,
498
+ "min_weight_fraction_leaf" : min_weight_fraction_leaf ,
499
+ "max_features" : max_features ,
500
+ "max_leaf_nodes" : max_leaf_nodes ,
501
+ "min_impurity_decrease" : min_impurity_decrease ,
502
+ "ccp_alpha" : ccp_alpha ,
503
+ "max_samples" : max_samples ,
504
+ }
505
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
506
+ if parse_version (sklearn_version .base_version ) >= parse_version ("1.4" ):
507
+ # use scikit-learn support for monotonic constraints
508
+ params_random_forest ["monotonic_cst" ] = monotonic_cst
509
+ else :
510
+ if monotonic_cst is not None :
511
+ raise ValueError (
512
+ "Monotonic constraints are not supported for scikit-learn "
513
+ "version < 1.4."
514
+ )
515
+ # create an attribute for compatibility with other scikit-learn tools such
516
+ # as HTML representation.
517
+ self .monotonic_cst = monotonic_cst
518
+ super ().__init__ (** params_random_forest )
483
519
484
520
self .sampling_strategy = sampling_strategy
485
521
self .replacement = replacement
@@ -591,11 +627,41 @@ def fit(self, X, y, sample_weight=None):
591
627
# Validate or convert input data
592
628
if issparse (y ):
593
629
raise ValueError ("sparse multilabel-indicator for y is not supported." )
630
+
631
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
632
+ # Support for missing values
633
+ if parse_version (sklearn_version .base_version ) >= parse_version ("1.4" ):
634
+ force_all_finite = False
635
+ else :
636
+ force_all_finite = True
637
+
594
638
X , y = self ._validate_data (
595
- X , y , multi_output = True , accept_sparse = "csc" , dtype = DTYPE
639
+ X ,
640
+ y ,
641
+ multi_output = True ,
642
+ accept_sparse = "csc" ,
643
+ dtype = DTYPE ,
644
+ force_all_finite = force_all_finite ,
596
645
)
646
+
647
+ # TODO: remove when the minimum supported version of scikit-learn will be 1.4
648
+ if parse_version (sklearn_version .base_version ) >= parse_version ("1.4" ):
649
+ # _compute_missing_values_in_feature_mask checks if X has missing values and
650
+ # will raise an error if the underlying tree base estimator can't handle
651
+ # missing values. Only the criterion is required to determine if the tree
652
+ # supports missing values.
653
+ estimator = type (self .estimator )(criterion = self .criterion )
654
+ missing_values_in_feature_mask = (
655
+ estimator ._compute_missing_values_in_feature_mask (
656
+ X , estimator_name = self .__class__ .__name__
657
+ )
658
+ )
659
+ else :
660
+ missing_values_in_feature_mask = None
661
+
597
662
if sample_weight is not None :
598
663
sample_weight = _check_sample_weight (sample_weight , X )
664
+
599
665
self ._n_features = X .shape [1 ]
600
666
601
667
if issparse (X ):
@@ -713,6 +779,7 @@ def fit(self, X, y, sample_weight=None):
713
779
class_weight = self .class_weight ,
714
780
n_samples_bootstrap = n_samples_bootstrap ,
715
781
forest = self ,
782
+ missing_values_in_feature_mask = missing_values_in_feature_mask ,
716
783
)
717
784
for i , (s , t ) in enumerate (zip (samplers , trees ))
718
785
)
0 commit comments