1
1
"""
2
2
This is a module to be used as a reference for building other modules
3
3
"""
4
- from sklearn .base import BaseEstimator
5
- from sklearn .utils import check_X_y , check_array
4
+ import numpy as np
5
+ from sklearn .base import BaseEstimator , ClassifierMixin , TransformerMixin
6
+ from sklearn .utils .validation import check_X_y , check_array , check_is_fitted
7
+ from sklearn .utils .multiclass import unique_labels
8
+ from sklearn .metrics import euclidean_distances
9
+
6
10
7
11
class TemplateEstimator (BaseEstimator ):
8
12
""" A template estimator to be used as a reference implementation .
9
-
13
+
10
14
Parameters
11
15
----------
12
16
demo_param : str, optional
@@ -21,7 +25,7 @@ def fit(self, X, y):
21
25
Parameters
22
26
----------
23
27
X : array-like or sparse matrix of shape = [n_samples, n_features]
24
- The training input samples.
28
+ The training input samples.
25
29
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
26
30
The target values (class labels in classification, real numbers in
27
31
regression).
@@ -47,4 +51,135 @@ def predict(self, X):
47
51
Returns :math:`x^2` where :math:`x` is the first column of `X`.
48
52
"""
49
53
X = check_array (X )
50
- return X [:, 0 ]** 2
54
+ return X [:, 0 ]** 2
55
+
56
+
57
+ class TemplateClassifier (BaseEstimator , ClassifierMixin ):
58
+ """ An example classifier which implements a 1-NN algorithm.
59
+
60
+ Parameters
61
+ ----------
62
+ demo_param : str, optional
63
+ A parameter used for demonstation of how to pass and store paramters.
64
+
65
+ Attributes
66
+ ----------
67
+ X_ : array, shape = [n_samples, n_features]
68
+ The input passed during :meth:`fit`
69
+ y_ : array, shape = [n_samples]
70
+ The labels passed during :meth:`fit`
71
+ """
72
+ def __init__ (self , demo_param = 'demo' ):
73
+ self .demo_param = demo_param
74
+
75
+ def fit (self , X , y ):
76
+ """A reference implementation of a fitting function for a classifier.
77
+
78
+ Parameters
79
+ ----------
80
+ X : array-like, shape = [n_samples, n_features]
81
+ The training input samples.
82
+ y : array-like, shape = [n_samples]
83
+ The target values. An array of int.
84
+ Returns
85
+ -------
86
+ self : object
87
+ Returns self.
88
+ """
89
+ # Check that X and y have correct shape
90
+ X , y = check_X_y (X , y )
91
+ # Store the classes seen durind fit
92
+ self .classes_ = unique_labels (y )
93
+
94
+ self .X_ = X
95
+ self .y_ = y
96
+ # Return the classifier
97
+ return self
98
+
99
+ def predict (self , X ):
100
+ """ A reference implementation of a prediction for a classifier.
101
+
102
+ Parameters
103
+ ----------
104
+ X : array-like of shape = [n_samples, n_features]
105
+ The input samples.
106
+
107
+ Returns
108
+ -------
109
+ y : array of int of shape = [n_samples]
110
+ The label for each sample is the label of the closest sample
111
+ seen udring fit.
112
+ """
113
+ # Check is fit had been called
114
+ check_is_fitted (self , ['X_' , 'y_' ])
115
+
116
+ # Input validation
117
+ X = check_array (X )
118
+
119
+ closest = np .argmin (euclidean_distances (X , self .X_ ), axis = 1 )
120
+ return self .y_ [closest ]
121
+
122
+
123
+ class TemplateTransformer (BaseEstimator , TransformerMixin ):
124
+ """ An example transformer that returns the element-wise square root..
125
+
126
+ Parameters
127
+ ----------
128
+ demo_param : str, optional
129
+ A parameter used for demonstation of how to pass and store paramters.
130
+
131
+ Attributes
132
+ ----------
133
+ input_shape : tuple
134
+ The shape the data passed to :meth:`fit`
135
+ """
136
+ def __init__ (self , demo_param = 'demo' ):
137
+ self .demo_param = demo_param
138
+
139
+ def fit (self , X , y = None ):
140
+ """A reference implementation of a fitting function for a transformer.
141
+
142
+ Parameters
143
+ ----------
144
+ X : array-like or sparse matrix of shape = [n_samples, n_features]
145
+ The training input samples.
146
+ y : array-like, shape = [n_samples]
147
+ The target values. An array of int.
148
+ Returns
149
+ -------
150
+ self : object
151
+ Returns self.
152
+ """
153
+ X = check_array (X )
154
+
155
+ self .input_shape_ = X .shape
156
+
157
+ # Return the classifier
158
+ return self
159
+
160
+ def transform (self , X ):
161
+ """ A reference implementation of a transform function.
162
+
163
+ Parameters
164
+ ----------
165
+ X : array-like of shape = [n_samples, n_features]
166
+ The input samples.
167
+
168
+ Returns
169
+ -------
170
+ X_transformed : array of int of shape = [n_samples, n_features]
171
+ The array containing the element-wise square roots of the values
172
+ in `X`
173
+ """
174
+ # Check is fit had been called
175
+ check_is_fitted (self , ['input_shape_' ])
176
+
177
+ # Input validation
178
+ X = check_array (X )
179
+
180
+ # Check that the input is of the same shape as the one passed
181
+ # during fit.
182
+ if X .shape != self .input_shape_ :
183
+ raise ValueError ('Shape of input is different from what was seen'
184
+ 'in `fit`' )
185
+ return np .sqrt (X )
0 commit comments