Skip to content

Commit bb4113c

Browse files
author
RJ Agrawal
committed
removed cross_val_score as it is not required
1 parent 22dc685 commit bb4113c

File tree

5 files changed

+26
-61
lines changed

5 files changed

+26
-61
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ Import
3232
Import what you need from the ``sklearn_pandas`` package. The choices are:
3333

3434
* ``DataFrameMapper``, a class for mapping pandas data frame columns to different sklearn transformations
35-
* ``cross_val_score``, similar to ``sklearn.cross_validation.cross_val_score`` but working on pandas DataFrames
35+
3636

3737
For this demonstration, we will import it::
3838

39-
>>> from sklearn_pandas import DataFrameMapper, cross_val_score
39+
>>> from sklearn_pandas import DataFrameMapper
4040

4141
For these examples, we'll also use pandas, numpy, and sklearn::
4242

sklearn_pandas/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
__version__ = '2.0.0'
22

33
from .dataframe_mapper import DataFrameMapper # NOQA
4-
from .cross_validation import cross_val_score, GridSearchCV, RandomizedSearchCV # NOQA
54
from .features_generator import gen_features # NOQA
65
from .transformers import NumericalTransformer # NOQA

sklearn_pandas/cross_validation.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,3 @@
1-
import warnings
2-
try:
3-
from sklearn.model_selection import cross_val_score as sk_cross_val_score
4-
from sklearn.model_selection import GridSearchCV as SKGridSearchCV
5-
from sklearn.model_selection import RandomizedSearchCV as \
6-
SKRandomizedSearchCV
7-
except ImportError:
8-
from sklearn.cross_validation import cross_val_score as sk_cross_val_score
9-
from sklearn.grid_search import GridSearchCV as SKGridSearchCV
10-
from sklearn.grid_search import RandomizedSearchCV as SKRandomizedSearchCV
11-
12-
DEPRECATION_MSG = '''
13-
Custom cross-validation compatibility shims are no longer needed for
14-
scikit-learn>=0.16.0 and will be dropped in sklearn-pandas==2.0.
15-
'''
16-
17-
18-
def cross_val_score(model, X, *args, **kwargs):
19-
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
20-
X = DataWrapper(X)
21-
return sk_cross_val_score(model, X, *args, **kwargs)
22-
23-
24-
class GridSearchCV(SKGridSearchCV):
25-
26-
def __init__(self, *args, **kwargs):
27-
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
28-
super(GridSearchCV, self).__init__(*args, **kwargs)
29-
30-
def fit(self, X, *params, **kwparams):
31-
return super(GridSearchCV, self).fit(
32-
DataWrapper(X), *params, **kwparams)
33-
34-
def predict(self, X, *params, **kwparams):
35-
return super(GridSearchCV, self).predict(
36-
DataWrapper(X), *params, **kwparams)
37-
38-
39-
try:
40-
class RandomizedSearchCV(SKRandomizedSearchCV):
41-
42-
def __init__(self, *args, **kwargs):
43-
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
44-
super(RandomizedSearchCV, self).__init__(*args, **kwargs)
45-
46-
def fit(self, X, *params, **kwparams):
47-
return super(RandomizedSearchCV, self).fit(
48-
DataWrapper(X), *params, **kwparams)
49-
50-
def predict(self, X, *params, **kwparams):
51-
return super(RandomizedSearchCV, self).predict(
52-
DataWrapper(X), *params, **kwparams)
53-
except AttributeError:
54-
pass
55-
56-
571
class DataWrapper(object):
582

593
def __init__(self, df):

tests/test_dataframe_mapper.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from scipy import sparse
1515
from sklearn.datasets import load_iris
1616
from sklearn.pipeline import Pipeline
17+
from sklearn.model_selection import cross_val_score
1718
from sklearn.svm import SVC
1819
from sklearn.feature_extraction.text import CountVectorizer
1920
from sklearn.feature_extraction import DictVectorizer
@@ -27,7 +28,7 @@
2728
from numpy.testing import assert_array_equal
2829
import pickle
2930

30-
from sklearn_pandas import DataFrameMapper, cross_val_score
31+
from sklearn_pandas import DataFrameMapper
3132
from sklearn_pandas.dataframe_mapper import _handle_feature, _build_transformer
3233
from sklearn_pandas.pipeline import TransformerPipeline
3334

@@ -882,6 +883,27 @@ def test_with_car_dataframe(cars_dataframe):
882883
assert scores.mean() > 0.30
883884

884885

886+
def test_direct_cross_validation(iris_dataframe):
887+
"""
888+
Starting with sklearn>=0.16.0 we no longer need CV wrappers for dataframes.
889+
See https://github.com/paulgb/sklearn-pandas/issues/11
890+
"""
891+
pipeline = Pipeline([
892+
("preprocess", DataFrameMapper([
893+
("petal length (cm)", None),
894+
("petal width (cm)", None),
895+
("sepal length (cm)", None),
896+
("sepal width (cm)", None),
897+
])),
898+
("classify", SVC(kernel='linear'))
899+
])
900+
data = iris_dataframe.drop("species", axis=1)
901+
labels = iris_dataframe["species"]
902+
scores = cross_val_score(pipeline, data, labels)
903+
assert scores.mean() > 0.96
904+
assert (scores.std() * 2) < 0.04
905+
906+
885907
def test_heterogeneous_output_types_input_df():
886908
"""
887909
Modify feat2, but pass feat1 through unmodified.

tests/test_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import joblib
66

77
from sklearn_pandas import DataFrameMapper
8-
from sklearn_pandas.transformers import NumericalTransformer
8+
from sklearn_pandas import NumericalTransformer
99

1010

1111
@pytest.fixture

0 commit comments

Comments
 (0)