Skip to content

Commit 628f4a4

Browse files
authored
MAINT add support for dataframe in parameter validation framework (#957)
1 parent a84b63f commit 628f4a4

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

imblearn/utils/_param_validation.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
sklearn_version = parse_version(sklearn.__version__)
2222

23-
if sklearn_version < parse_version("1.2"):
23+
# if sklearn_version < parse_version("1.2"):
24+
if True:
25+
# TODO: remove `if True` when we have clear support for:
26+
# - dataframe
2427

2528
def validate_parameter_constraints(parameter_constraints, params, caller_name):
2629
"""Validate types and values of given parameters.
@@ -35,6 +38,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
3538
Constraints can be:
3639
- an Interval object, representing a continuous or discrete range of numbers
3740
- the string "array-like"
41+
- the string "dataframe"
3842
- the string "sparse matrix"
3943
- the string "random_state"
4044
- callable
@@ -115,6 +119,8 @@ def make_constraint(constraint):
115119
return _ArrayLikes()
116120
if isinstance(constraint, str) and constraint == "sparse matrix":
117121
return _SparseMatrices()
122+
if isinstance(constraint, str) and constraint == "dataframe":
123+
return _DataFrames()
118124
if isinstance(constraint, str) and constraint == "random_state":
119125
return _RandomStates()
120126
if constraint is callable:
@@ -466,6 +472,17 @@ def is_satisfied_by(self, val):
466472
def __str__(self):
467473
return "a sparse matrix"
468474

475+
class _DataFrames(_Constraint):
    """Constraint satisfied by dataframe-like objects.

    A value qualifies when it exposes a ``__dataframe__`` attribute (the
    dataframe protocol) or, failing that, an ``iloc`` attribute, which is the
    duck-typing fallback for older pandas versions.
    """

    def __str__(self):
        return "a DataFrame"

    def is_satisfied_by(self, val):
        """Return True when `val` has ``__dataframe__`` or ``iloc``."""
        # The protocol attribute is probed first; `iloc` is only looked up
        # when the protocol attribute is missing.
        return any(hasattr(val, attr) for attr in ("__dataframe__", "iloc"))
485+
469486
class _Callables(_Constraint):
470487
"""Constraint representing callables."""
471488

@@ -845,6 +862,11 @@ def generate_valid_param(constraint):
845862
if isinstance(constraint, _SparseMatrices):
846863
return csr_matrix([[0, 1], [1, 0]])
847864

865+
if isinstance(constraint, _DataFrames):
866+
import pandas as pd
867+
868+
return pd.DataFrame({"a": [1, 2, 3]})
869+
848870
if isinstance(constraint, _RandomStates):
849871
return np.random.RandomState(42)
850872

imblearn/utils/tests/test_param_validation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_Booleans,
2222
_Callables,
2323
_CVObjects,
24+
_DataFrames,
2425
_InstancesOf,
2526
_IterablesNotString,
2627
_MissingValues,
@@ -36,6 +37,15 @@
3637
)
3738

3839

40+
def has_pandas():
    """Report whether pandas is importable and provide a sample DataFrame.

    Returns
    -------
    tuple
        ``(True, df)`` with ``df = pd.DataFrame({"a": [1, 2, 3]})`` when the
        import succeeds, ``(False, None)`` when an ``ImportError`` is raised.
    """
    outcome = (False, None)
    try:
        import pandas as pd

        outcome = (True, pd.DataFrame({"a": [1, 2, 3]}))
    except ImportError:
        # pandas is unavailable: keep the negative default set above.
        pass
    return outcome
47+
48+
3949
# Some helpers for the tests
4050
@validate_params({"a": [Real], "b": [Real], "c": [Real], "d": [Real]})
4151
def _func(a, b=0, *args, c, d=0, **kwargs):
@@ -317,6 +327,12 @@ def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval)
317327
"constraints",
318328
[
319329
[_ArrayLikes()],
330+
pytest.param(
331+
[_DataFrames()],
332+
marks=pytest.mark.skipif(
333+
not has_pandas()[0], reason="Pandas not installed"
334+
),
335+
),
320336
[_InstancesOf(list)],
321337
[_Callables()],
322338
[_NoneConstraint()],
@@ -342,6 +358,12 @@ def test_generate_invalid_param_val_all_valid(constraints):
342358
"constraint",
343359
[
344360
_ArrayLikes(),
361+
pytest.param(
362+
_DataFrames(),
363+
marks=pytest.mark.skipif(
364+
not has_pandas()[0], reason="Pandas not installed"
365+
),
366+
),
345367
_Callables(),
346368
_InstancesOf(list),
347369
_NoneConstraint(),
@@ -381,6 +403,13 @@ def test_generate_valid_param(constraint):
381403
(None, None),
382404
("array-like", [[1, 2], [3, 4]]),
383405
("array-like", np.array([[1, 2], [3, 4]])),
406+
pytest.param(
407+
"dataframe",
408+
has_pandas()[1],
409+
marks=pytest.mark.skipif(
410+
not has_pandas()[0], reason="Pandas not installed"
411+
),
412+
),
384413
("sparse matrix", csr_matrix([[1, 2], [3, 4]])),
385414
("random_state", 0),
386415
("random_state", np.random.RandomState(0)),
@@ -414,6 +443,13 @@ def test_is_satisfied_by(constraint_declaration, value):
414443
(Options(Real, {0.42, 1.23}), Options),
415444
("array-like", _ArrayLikes),
416445
("sparse matrix", _SparseMatrices),
446+
pytest.param(
447+
"dataframe",
448+
_DataFrames,
449+
marks=pytest.mark.skipif(
450+
not has_pandas()[0], reason="Pandas not installed"
451+
),
452+
),
417453
("random_state", _RandomStates),
418454
(None, _NoneConstraint),
419455
(callable, _Callables),

0 commit comments

Comments
 (0)