
Commit 0a659af

MAINT compatibility with sklearn 1.4 (#1045)

1 parent 42a7909 commit 0a659af

File tree

7 files changed: +2546 additions, -109 deletions

doc/whats_new/v0.12.rst

Lines changed: 8 additions & 0 deletions

@@ -18,6 +18,14 @@ Bug fixes
   the number of samples in the minority class.
   :pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

+Compatibility
+.............
+
+- :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now support missing values
+  and monotonic constraints if scikit-learn >= 1.4 is installed.
+- :class:`~imblearn.pipeline.Pipeline` support metadata routing if scikit-learn >= 1.4
+  is installed.
+
 Deprecations
 ............
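A usage sketch for these two changelog entries (illustrative, not part of the commit; it assumes scikit-learn >= 1.4 and imbalanced-learn 0.12, and borrows the `sampling_strategy="all", replacement=True, bootstrap=False` settings from the test suite further below):

# Usage sketch, assuming scikit-learn >= 1.4 and imbalanced-learn 0.12.
import numpy as np
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

X, y = make_classification(
    n_samples=500, n_features=4, weights=[0.9, 0.1], random_state=0
)

# Missing values: NaN now flows through to the underlying trees.
X_missing = X.copy()
X_missing[::50, 0] = np.nan
clf = BalancedRandomForestClassifier(
    sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
).fit(X_missing, y)

# Monotonic constraints: the positive-class probability is constrained to be
# non-decreasing in feature 0 (constraints cannot be combined with missing
# values, hence the clean X).
clf_mono = BalancedRandomForestClassifier(
    monotonic_cst=[1, 0, 0, 0],
    sampling_strategy="all",
    replacement=True,
    bootstrap=False,
    random_state=0,
).fit(X, y)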

imblearn/ensemble/_common.py

Lines changed: 1 addition & 0 deletions

@@ -101,4 +101,5 @@ def check(self):
         list,
         None,
     ],
+    "monotonic_cst": ["array-like", None],
 }
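For orientation, a toy sketch of how such a declarative constraints mapping is consumed (simplified; the real validator lives in `sklearn.utils._param_validation` and runs at fit time): each parameter value must satisfy at least one of its declared specs.

# Toy sketch, not scikit-learn's actual implementation.
def check_param(constraints, name, value):
    def satisfies(spec, value):
        if spec is None:
            return value is None
        if spec == "array-like":
            return hasattr(value, "__len__") or hasattr(value, "shape")
        return False

    if not any(satisfies(spec, value) for spec in constraints[name]):
        raise ValueError(f"{name}={value!r} does not satisfy {constraints[name]}")

constraints = {"monotonic_cst": ["array-like", None]}
check_param(constraints, "monotonic_cst", [1, 0, -1])  # ok
check_param(constraints, "monotonic_cst", None)        # ok
check_param(constraints, "monotonic_cst", 1)           # raises ValueError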

imblearn/ensemble/_forest.py

Lines changed: 115 additions & 48 deletions

@@ -60,6 +60,7 @@ def _local_parallel_build_trees(
     class_weight=None,
     n_samples_bootstrap=None,
     forest=None,
+    missing_values_in_feature_mask=None,
 ):
     # resample before to fit the tree
     X_resampled, y_resampled = sampler.fit_resample(X, y)

@@ -68,33 +69,34 @@
     if _get_n_samples_bootstrap is not None:
         n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])

-    if sklearn_version >= parse_version("1.1"):
-        tree = _parallel_build_trees(
-            tree,
-            bootstrap,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
+    params_parallel_build_trees = {
+        "tree": tree,
+        "X": X_resampled,
+        "y": y_resampled,
+        "sample_weight": sample_weight,
+        "tree_idx": tree_idx,
+        "n_trees": n_trees,
+        "verbose": verbose,
+        "class_weight": class_weight,
+        "n_samples_bootstrap": n_samples_bootstrap,
+    }
+
+    if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        # support for missing values
+        params_parallel_build_trees[
+            "missing_values_in_feature_mask"
+        ] = missing_values_in_feature_mask
+
+    # TODO: remove when the minimum supported version of scikit-learn will be 1.1
+    # change of signature in scikit-learn 1.1
+    if parse_version(sklearn_version.base_version) >= parse_version("1.1"):
+        params_parallel_build_trees["bootstrap"] = bootstrap
     else:
-        # TODO: remove when the minimum version of scikit-learn supported is 1.1
-        tree = _parallel_build_trees(
-            tree,
-            forest,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
+        params_parallel_build_trees["forest"] = forest
+
+    tree = _parallel_build_trees(**params_parallel_build_trees)
+
     return sampler, tree
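A note on the version gate above: the comparison goes through `sklearn_version.base_version` so that development and release-candidate builds of 1.4 take the new branch. A minimal sketch of the difference, using the `parse_version` helper re-exported by scikit-learn:

# Why compare on base_version: a dev build sorts *below* its final release,
# so a naive comparison would wrongly take the legacy code path.
from sklearn.utils.fixes import parse_version

dev = parse_version("1.4.dev0")
print(dev >= parse_version("1.4"))                              # False
print(parse_version(dev.base_version) >= parse_version("1.4"))  # True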

@@ -305,6 +307,25 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
         .. versionadded:: 0.6
            Added in `scikit-learn` in 0.22

+    monotonic_cst : array-like of int of shape (n_features), default=None
+        Indicates the monotonicity constraint to enforce on each feature.
+          - 1: monotonic increase
+          - 0: no constraint
+          - -1: monotonic decrease
+
+        If monotonic_cst is None, no constraints are applied.
+
+        Monotonicity constraints are not supported for:
+          - multiclass classifications (i.e. when `n_classes > 2`),
+          - multioutput classifications (i.e. when `n_outputs_ > 1`),
+          - classifications trained on data with missing values.
+
+        The constraints hold over the probability of the positive class.
+
+        .. versionadded:: 0.12
+           Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
+           `ValueError` is raised.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
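A hedged sketch (data and names illustrative, not from the commit) of what "the constraints hold over the probability of the positive class" means in practice: sweeping the constrained feature while holding the others fixed yields a non-decreasing predict_proba.

# Sketch: verify monotonicity of the positive-class probability in feature 0.
import numpy as np
from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

X, y = make_classification(n_samples=500, n_features=4, random_state=0)
clf = BalancedRandomForestClassifier(
    monotonic_cst=[1, 0, 0, 0],
    sampling_strategy="all",
    replacement=True,
    bootstrap=False,
    random_state=0,
).fit(X, y)

X_sweep = np.tile(X.mean(axis=0), (50, 1))  # hold features 1-3 fixed
X_sweep[:, 0] = np.linspace(X[:, 0].min(), X[:, 0].max(), 50)
proba = clf.predict_proba(X_sweep)[:, 1]
assert np.all(np.diff(proba) >= -1e-12)     # non-decreasing in feature 0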
@@ -415,7 +436,7 @@ class labels (multi-output problem).
     """

     # make a deepcopy to not modify the original dictionary
-    if sklearn_version >= parse_version("1.3"):
+    if sklearn_version >= parse_version("1.4"):
         _parameter_constraints = deepcopy(RandomForestClassifier._parameter_constraints)
     else:
         _parameter_constraints = deepcopy(

@@ -459,27 +480,42 @@ def __init__(
         class_weight=None,
         ccp_alpha=0.0,
         max_samples=None,
+        monotonic_cst=None,
     ):
-        super().__init__(
-            criterion=criterion,
-            max_depth=max_depth,
-            n_estimators=n_estimators,
-            bootstrap=bootstrap,
-            oob_score=oob_score,
-            n_jobs=n_jobs,
-            random_state=random_state,
-            verbose=verbose,
-            warm_start=warm_start,
-            class_weight=class_weight,
-            min_samples_split=min_samples_split,
-            min_samples_leaf=min_samples_leaf,
-            min_weight_fraction_leaf=min_weight_fraction_leaf,
-            max_features=max_features,
-            max_leaf_nodes=max_leaf_nodes,
-            min_impurity_decrease=min_impurity_decrease,
-            ccp_alpha=ccp_alpha,
-            max_samples=max_samples,
-        )
+        params_random_forest = {
+            "criterion": criterion,
+            "max_depth": max_depth,
+            "n_estimators": n_estimators,
+            "bootstrap": bootstrap,
+            "oob_score": oob_score,
+            "n_jobs": n_jobs,
+            "random_state": random_state,
+            "verbose": verbose,
+            "warm_start": warm_start,
+            "class_weight": class_weight,
+            "min_samples_split": min_samples_split,
+            "min_samples_leaf": min_samples_leaf,
+            "min_weight_fraction_leaf": min_weight_fraction_leaf,
+            "max_features": max_features,
+            "max_leaf_nodes": max_leaf_nodes,
+            "min_impurity_decrease": min_impurity_decrease,
+            "ccp_alpha": ccp_alpha,
+            "max_samples": max_samples,
+        }
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            # use scikit-learn support for monotonic constraints
+            params_random_forest["monotonic_cst"] = monotonic_cst
+        else:
+            if monotonic_cst is not None:
+                raise ValueError(
+                    "Monotonic constraints are not supported for scikit-learn "
+                    "version < 1.4."
+                )
+            # create an attribute for compatibility with other scikit-learn tools such
+            # as HTML representation.
+            self.monotonic_cst = monotonic_cst
+        super().__init__(**params_random_forest)

         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
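The `self.monotonic_cst = monotonic_cst` assignment on the legacy branch is worth a word: scikit-learn's `get_params` (and therefore cloning and the HTML representation) reads every constructor argument back from an attribute of the same name, so the parameter must exist on the instance even when it cannot be forwarded to the parent `__init__`. A minimal sketch of that convention:

# Sketch of the scikit-learn convention motivating the manual assignment.
from sklearn.base import BaseEstimator

class Toy(BaseEstimator):
    def __init__(self, monotonic_cst=None):
        self.monotonic_cst = monotonic_cst  # without this, get_params() errors

print(Toy().get_params())  # {'monotonic_cst': None}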
@@ -591,11 +627,41 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
+
+        # TODO: remove when the minimum supported version of scipy will be 1.4
+        # Support for missing values
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            force_all_finite = False
+        else:
+            force_all_finite = True
+
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=force_all_finite,
         )
+
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            # _compute_missing_values_in_feature_mask checks if X has missing values and
+            # will raise an error if the underlying tree base estimator can't handle
+            # missing values. Only the criterion is required to determine if the tree
+            # supports missing values.
+            estimator = type(self.estimator)(criterion=self.criterion)
+            missing_values_in_feature_mask = (
+                estimator._compute_missing_values_in_feature_mask(
+                    X, estimator_name=self.__class__.__name__
+                )
+            )
+        else:
+            missing_values_in_feature_mask = None
+
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
+
         self._n_features = X.shape[1]

         if issparse(X):
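A hedged sketch of the `force_all_finite` switch in isolation, using the public `check_array` that `_validate_data` wraps: by default NaN is rejected during input validation, and passing `force_all_finite=False` lets it through for the trees to handle.

# Sketch: the validation flag that lets NaN reach the estimator.
import numpy as np
from sklearn.utils import check_array

X = np.array([[1.0, np.nan], [3.0, 4.0]])

try:
    check_array(X)  # force_all_finite=True by default -> rejects NaN
except ValueError as exc:
    print("rejected:", exc)

X_passed = check_array(X, force_all_finite=False)  # NaN allowed through
print(np.isnan(X_passed).any())  # True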
@@ -713,6 +779,7 @@ def fit(self, X, y, sample_weight=None):
                 class_weight=self.class_weight,
                 n_samples_bootstrap=n_samples_bootstrap,
                 forest=self,
+                missing_values_in_feature_mask=missing_values_in_feature_mask,
             )
             for i, (s, t) in enumerate(zip(samplers, trees))
         )

imblearn/ensemble/tests/test_forest.py

Lines changed: 97 additions & 0 deletions

@@ -258,3 +258,100 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
     )
     with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
         estimator.fit(*imbalanced_dataset)
+
+
+@pytest.mark.skipif(
+    parse_version(sklearn_version.base_version) < parse_version("1.4"),
+    reason="scikit-learn should be >= 1.4",
+)
+def test_missing_values_is_resilient():
+    """Check that forest can deal with missing values and has decent performance."""
+
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 1000, 10
+    X, y = make_classification(
+        n_samples=n_samples, n_features=n_features, random_state=rng
+    )
+
+    # Create dataset with missing values
+    X_missing = X.copy()
+    X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
+    assert np.isnan(X_missing).any()
+
+    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
+        X_missing, y, random_state=0
+    )
+
+    # Train forest with missing values
+    forest_with_missing = BalancedRandomForestClassifier(
+        sampling_strategy="all",
+        replacement=True,
+        bootstrap=False,
+        random_state=rng,
+        n_estimators=50,
+    )
+    forest_with_missing.fit(X_missing_train, y_train)
+    score_with_missing = forest_with_missing.score(X_missing_test, y_test)
+
+    # Train forest without missing values
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    forest = BalancedRandomForestClassifier(
+        sampling_strategy="all",
+        replacement=True,
+        bootstrap=False,
+        random_state=rng,
+        n_estimators=50,
+    )
+    forest.fit(X_train, y_train)
+    score_without_missing = forest.score(X_test, y_test)
+
+    # Score is still 80 percent of the forest's score that had no missing values
+    assert score_with_missing >= 0.80 * score_without_missing
+
+
+@pytest.mark.skipif(
+    parse_version(sklearn_version.base_version) < parse_version("1.4"),
+    reason="scikit-learn should be >= 1.4",
+)
+def test_missing_value_is_predictive():
+    """Check that the forest learns when missing values are only present for
+    a predictive feature."""
+    rng = np.random.RandomState(0)
+    n_samples = 300
+
+    X_non_predictive = rng.standard_normal(size=(n_samples, 10))
+    y = rng.randint(0, high=2, size=n_samples)
+
+    # Create a predictive feature using `y` and with some noise
+    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
+    y_mask = y.astype(bool)
+    y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+    predictive_feature = rng.standard_normal(size=n_samples)
+    predictive_feature[y_mask] = np.nan
+    assert np.isnan(predictive_feature).any()
+
+    X_predictive = X_non_predictive.copy()
+    X_predictive[:, 5] = predictive_feature
+
+    (
+        X_predictive_train,
+        X_predictive_test,
+        X_non_predictive_train,
+        X_non_predictive_test,
+        y_train,
+        y_test,
+    ) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
+    forest_predictive = BalancedRandomForestClassifier(
+        sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
+    ).fit(X_predictive_train, y_train)
+    forest_non_predictive = BalancedRandomForestClassifier(
+        sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
+    ).fit(X_non_predictive_train, y_train)
+
+    predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
+
+    assert predictive_test_score >= 0.75
+    assert predictive_test_score >= forest_non_predictive.score(
+        X_non_predictive_test, y_test
+    )
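To run only these two tests locally (a sketch, assuming a source checkout with scikit-learn >= 1.4 installed; both test names match the "missing_value" keyword):

# Invoke pytest programmatically on the new tests only.
import pytest

pytest.main(["imblearn/ensemble/tests/test_forest.py", "-k", "missing_value"])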
