Skip to content

Commit 243bdb1

Browse files
Added TemplateTransformer
1 parent ac6e63b commit 243bdb1

File tree

6 files changed

+110
-10
lines changed

6 files changed

+110
-10
lines changed

examples/plot_classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Plotting Template Classifier
44
============================
55
6-
An example plot of TemplateClassifier
6+
An example plot of :class:`skltemplate.template.TemplateClassifier`
77
"""
88
import numpy as np
99
from skltemplate import TemplateClassifier
@@ -22,7 +22,7 @@
2222
X_0 = X_test[y_pred == 0]
2323
X_1 = X_test[y_pred == 1]
2424

25-
ax0 = plt.scatter(X_0[:, 0], X_0[:, 1], c='darkorange', s=30)
25+
ax0 = plt.scatter(X_0[:, 0], X_0[:, 1], c='crimson', s=30)
2626
ax1 = plt.scatter(X_1[:, 0], X_1[:, 1], c='deepskyblue', s=30)
2727

2828

examples/plot_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Plotting Template Estimator
44
===========================
55
6-
An example plot of TemplateEstimator
6+
An example plot of :class:`skltemplate.template.TemplateEstimator`
77
"""
88
import numpy as np
99
from skltemplate import TemplateEstimator
@@ -14,4 +14,4 @@
1414
estimator = TemplateEstimator()
1515
estimator.fit(X, y)
1616
plt.plot(estimator.predict(X))
17-
plt.show()
17+
plt.show()

examples/plot_transformer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
=============================
3+
Plotting Template Transformer
4+
=============================
5+
6+
An example plot of :class:`skltemplate.template.TemplateTransformer`
7+
"""
8+
import numpy as np
9+
from skltemplate import TemplateTransformer
10+
from matplotlib import pyplot as plt
11+
12+
X = np.arange(50).reshape(-1, 1)
13+
estimator = TemplateTransformer()
14+
X_transformed = estimator.fit_transform(X)
15+
16+
plt.plot(X.flatten()/X.max(), label='Original Data')
17+
plt.plot(X_transformed.flatten()/X_transformed.max(), label='Transformed Data')
18+
plt.title('Scaled plots of original and transformed data')
19+
20+
plt.legend(loc='best')
21+
plt.show()

skltemplate/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from .template import TemplateEstimator, TemplateClassifier
1+
from .template import (TemplateEstimator, TemplateClassifier,
2+
TemplateTransformer)
23
from . import template
34

4-
__all__ = ['TemplateEstimator', 'TemplateClassifier', 'template']
5+
__all__ = ['TemplateEstimator', 'TemplateClassifier',
6+
'TemplateTransformer', 'template']

skltemplate/template.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This is a module to be used as a reference for building other modules
33
"""
44
import numpy as np
5-
from sklearn.base import BaseEstimator, ClassifierMixin
5+
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
66
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
77
from sklearn.utils.multiclass import unique_labels
88
from sklearn.metrics import euclidean_distances
@@ -55,12 +55,19 @@ def predict(self, X):
5555

5656

5757
class TemplateClassifier(BaseEstimator, ClassifierMixin):
58-
""" An example for a classifier which implements a 1-NN algorithm.
58+
""" An example classifier which implements a 1-NN algorithm.
5959
6060
Parameters
6161
----------
6262
demo_param : str, optional
6363
A parameter used for demonstation of how to pass and store paramters.
64+
65+
Attributes
66+
----------
67+
X_ : array, shape = [n_samples, n_features]
68+
The input passed during :meth:`fit`
69+
y_ : array, shape = [n_samples]
70+
The labels passed during :meth:`fit`
6471
"""
6572
def __init__(self, demo_param='demo'):
6673
self.demo_param = demo_param
@@ -70,7 +77,7 @@ def fit(self, X, y):
7077
7178
Parameters
7279
----------
73-
X : array-like or sparse matrix of shape = [n_samples, n_features]
80+
X : array-like, shape = [n_samples, n_features]
7481
The training input samples.
7582
y : array-like, shape = [n_samples]
7683
The target values. An array of int.
@@ -111,3 +118,68 @@ def predict(self, X):
111118

112119
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
113120
return self.y_[closest]
121+
122+
123+
class TemplateTransformer(BaseEstimator, TransformerMixin):
124+
""" An example transformer that returns the element-wise square root..
125+
126+
Parameters
127+
----------
128+
demo_param : str, optional
129+
A parameter used for demonstation of how to pass and store paramters.
130+
131+
Attributes
132+
----------
133+
input_shape : tuple
134+
The shape the data passed to :meth:`fit`
135+
"""
136+
def __init__(self, demo_param='demo'):
137+
self.demo_param = demo_param
138+
139+
def fit(self, X, y=None):
140+
"""A reference implementation of a fitting function for a transformer.
141+
142+
Parameters
143+
----------
144+
X : array-like or sparse matrix of shape = [n_samples, n_features]
145+
The training input samples.
146+
y : array-like, shape = [n_samples]
147+
The target values. An array of int.
148+
Returns
149+
-------
150+
self : object
151+
Returns self.
152+
"""
153+
X = check_array(X)
154+
155+
self.input_shape_ = X.shape
156+
157+
# Return the classifier
158+
return self
159+
160+
def transform(self, X):
161+
""" A reference implementation of a transform function.
162+
163+
Parameters
164+
----------
165+
X : array-like of shape = [n_samples, n_features]
166+
The input samples.
167+
168+
Returns
169+
-------
170+
X_transformed : array of int of shape = [n_samples, n_features]
171+
The array containing the element-wise square roots of the values
172+
in `X`
173+
"""
174+
# Check is fit had been called
175+
check_is_fitted(self, ['input_shape_'])
176+
177+
# Input validation
178+
X = check_array(X)
179+
180+
# Check that the input is of the same shape as the one passed
181+
# during fit.
182+
if X.shape != self.input_shape_:
183+
raise ValueError('Shape of input is different from what was seen'
184+
'in `fit`')
185+
return np.sqrt(X)

skltemplate/tests/test_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from sklearn.utils.estimator_checks import check_estimator
2-
from skltemplate import TemplateEstimator, TemplateClassifier
2+
from skltemplate import (TemplateEstimator, TemplateClassifier,
3+
TemplateTransformer)
34

45

56
def test_estimator():
@@ -8,3 +9,7 @@ def test_estimator():
89

910
def test_classifier():
1011
return check_estimator(TemplateClassifier)
12+
13+
14+
def test_transformer():
15+
return check_estimator(TemplateTransformer)

0 commit comments

Comments
 (0)