File tree Expand file tree Collapse file tree 11 files changed +27
-15
lines changed Expand file tree Collapse file tree 11 files changed +27
-15
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,9 @@ Bug fixes
60
60
and thus to obtain a deterministic results when using the same random state.
61
61
:issue: `447 ` by :user: `Guillaume Lemaitre <glemaitre> `.
62
62
63
+ - Force to clone scikit-learn estimator passed as attributes to samplers.
64
+ :issue: `446 ` by :user: `Guillaume Lemaitre <glemaitre> `.
65
+
63
66
Maintenance
64
67
...........
65
68
Original file line number Diff line number Diff line change 9
9
import logging
10
10
import warnings
11
11
12
+ from sklearn .base import clone
12
13
from sklearn .utils import check_X_y
13
14
14
15
from ..base import SamplerMixin
@@ -103,7 +104,7 @@ def _validate_estimator(self):
103
104
"Private function to validate SMOTE and ENN objects"
104
105
if self .smote is not None :
105
106
if isinstance (self .smote , SMOTE ):
106
- self .smote_ = self .smote
107
+ self .smote_ = clone ( self .smote )
107
108
else :
108
109
raise ValueError ('smote needs to be a SMOTE object.'
109
110
'Got {} instead.' .format (type (self .smote )))
@@ -116,7 +117,7 @@ def _validate_estimator(self):
116
117
117
118
if self .enn is not None :
118
119
if isinstance (self .enn , EditedNearestNeighbours ):
119
- self .enn_ = self .enn
120
+ self .enn_ = clone ( self .enn )
120
121
else :
121
122
raise ValueError ('enn needs to be an EditedNearestNeighbours.'
122
123
' Got {} instead.' .format (type (self .enn )))
Original file line number Diff line number Diff line change 10
10
import logging
11
11
import warnings
12
12
13
+ from sklearn .base import clone
13
14
from sklearn .utils import check_X_y
14
15
15
16
from ..base import SamplerMixin
@@ -111,7 +112,7 @@ def _validate_estimator(self):
111
112
112
113
if self .smote is not None :
113
114
if isinstance (self .smote , SMOTE ):
114
- self .smote_ = self .smote
115
+ self .smote_ = clone ( self .smote )
115
116
else :
116
117
raise ValueError ('smote needs to be a SMOTE object.'
117
118
'Got {} instead.' .format (type (self .smote )))
@@ -124,7 +125,7 @@ def _validate_estimator(self):
124
125
125
126
if self .tomek is not None :
126
127
if isinstance (self .tomek , TomekLinks ):
127
- self .tomek_ = self .tomek
128
+ self .tomek_ = clone ( self .tomek )
128
129
else :
129
130
raise ValueError ('tomek needs to be a TomekLinks object.'
130
131
'Got {} instead.' .format (type (self .tomek )))
Original file line number Diff line number Diff line change 8
8
9
9
import numpy as np
10
10
11
- from sklearn .base import ClassifierMixin
11
+ from sklearn .base import ClassifierMixin , clone
12
12
from sklearn .neighbors import KNeighborsClassifier
13
13
from sklearn .utils import check_random_state , safe_indexing
14
14
from sklearn .model_selection import cross_val_predict
@@ -142,7 +142,7 @@ def _validate_estimator(self):
142
142
if (self .estimator is not None and
143
143
isinstance (self .estimator , ClassifierMixin ) and
144
144
hasattr (self .estimator , 'predict' )):
145
- self .estimator_ = self .estimator
145
+ self .estimator_ = clone ( self .estimator )
146
146
elif self .estimator is None :
147
147
self .estimator_ = KNeighborsClassifier ()
148
148
else :
Original file line number Diff line number Diff line change 14
14
15
15
from scipy import sparse
16
16
17
+ from sklearn .base import clone
17
18
from sklearn .svm import SVC
18
19
from sklearn .utils import check_random_state , safe_indexing
19
20
@@ -448,7 +449,7 @@ def _validate_estimator(self):
448
449
if self .svm_estimator is None :
449
450
self .svm_estimator_ = SVC (random_state = self .random_state )
450
451
elif isinstance (self .svm_estimator , SVC ):
451
- self .svm_estimator_ = self .svm_estimator
452
+ self .svm_estimator_ = clone ( self .svm_estimator )
452
453
else :
453
454
raise_isinstance_error ('svm_estimator' , [SVC ],
454
455
self .svm_estimator )
@@ -698,7 +699,7 @@ def _validate_estimator(self):
698
699
self .svm_estimator == 'deprecated' ):
699
700
self .svm_estimator_ = SVC (random_state = self .random_state )
700
701
elif isinstance (self .svm_estimator , SVC ):
701
- self .svm_estimator_ = self .svm_estimator
702
+ self .svm_estimator_ = clone ( self .svm_estimator )
702
703
else :
703
704
raise_isinstance_error ('svm_estimator' , [SVC ],
704
705
self .svm_estimator )
Original file line number Diff line number Diff line change 11
11
import numpy as np
12
12
from scipy import sparse
13
13
14
+ from sklearn .base import clone
14
15
from sklearn .cluster import KMeans
15
16
from sklearn .neighbors import NearestNeighbors
16
17
from sklearn .utils import safe_indexing
@@ -113,7 +114,7 @@ def _validate_estimator(self):
113
114
self .estimator_ = KMeans (
114
115
random_state = self .random_state , n_jobs = self .n_jobs )
115
116
elif isinstance (self .estimator , KMeans ):
116
- self .estimator_ = self .estimator
117
+ self .estimator_ = clone ( self .estimator )
117
118
else :
118
119
raise ValueError ('`estimator` has to be a KMeans clustering.'
119
120
' Got {} instead.' .format (type (self .estimator )))
Original file line number Diff line number Diff line change 13
13
14
14
from scipy .sparse import issparse
15
15
16
+ from sklearn .base import clone
16
17
from sklearn .neighbors import KNeighborsClassifier
17
18
from sklearn .utils import check_random_state , safe_indexing
18
19
@@ -121,7 +122,7 @@ def _validate_estimator(self):
121
122
self .estimator_ = KNeighborsClassifier (
122
123
n_neighbors = self .n_neighbors , n_jobs = self .n_jobs )
123
124
elif isinstance (self .n_neighbors , KNeighborsClassifier ):
124
- self .estimator_ = self .n_neighbors
125
+ self .estimator_ = clone ( self .n_neighbors )
125
126
else :
126
127
raise ValueError ('`n_neighbors` has to be a int or an object'
127
128
' inhereited from KNeighborsClassifier.'
Original file line number Diff line number Diff line change 12
12
13
13
import numpy as np
14
14
15
- from sklearn .base import ClassifierMixin
15
+ from sklearn .base import ClassifierMixin , clone
16
16
from sklearn .ensemble import RandomForestClassifier
17
17
from sklearn .model_selection import StratifiedKFold
18
18
from sklearn .utils import safe_indexing
@@ -117,7 +117,7 @@ def _validate_estimator(self):
117
117
if (self .estimator is not None and
118
118
isinstance (self .estimator , ClassifierMixin ) and
119
119
hasattr (self .estimator , 'predict_proba' )):
120
- self .estimator_ = self .estimator
120
+ self .estimator_ = clone ( self .estimator )
121
121
elif self .estimator is None :
122
122
self .estimator_ = RandomForestClassifier (
123
123
random_state = self .random_state , n_jobs = self .n_jobs )
Original file line number Diff line number Diff line change 9
9
from collections import Counter
10
10
11
11
import numpy as np
12
+
13
+ from sklearn .base import clone
12
14
from sklearn .neighbors import KNeighborsClassifier
13
15
from sklearn .utils import check_random_state , safe_indexing
14
16
@@ -114,7 +116,7 @@ def _validate_estimator(self):
114
116
self .estimator_ = KNeighborsClassifier (
115
117
n_neighbors = self .n_neighbors , n_jobs = self .n_jobs )
116
118
elif isinstance (self .n_neighbors , KNeighborsClassifier ):
117
- self .estimator_ = self .n_neighbors
119
+ self .estimator_ = clone ( self .n_neighbors )
118
120
else :
119
121
raise ValueError ('`n_neighbors` has to be a int or an object'
120
122
' inhereited from KNeighborsClassifier.'
Original file line number Diff line number Diff line change @@ -36,7 +36,8 @@ def test_check_neighbors_object():
36
36
assert issubclass (type (estimator ), KNeighborsMixin )
37
37
assert estimator .n_neighbors == 2
38
38
estimator = NearestNeighbors (n_neighbors )
39
- assert estimator is check_neighbors_object (name , estimator )
39
+ estimator_cloned = check_neighbors_object (name , estimator )
40
+ assert estimator .n_neighbors == estimator_cloned .n_neighbors
40
41
n_neighbors = 'rnd'
41
42
with pytest .raises (ValueError , match = "has to be one of" ):
42
43
check_neighbors_object (name , n_neighbors )
You can’t perform that action at this time.
0 commit comments