Skip to content

Commit 4273901

Browse files
committed
Merge pull request #18 from vighneshbirodkar/new_classes
Added demo code for a Classifier and Transformer
2 parents 5da57b3 + b25babd commit 4273901

File tree

6 files changed

+225
-11
lines changed

6 files changed

+225
-11
lines changed

examples/plot_classifier.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
============================
3+
Plotting Template Classifier
4+
============================
5+
6+
An example plot of :class:`skltemplate.template.TemplateClassifier`
7+
"""
8+
import numpy as np
9+
from skltemplate import TemplateClassifier
10+
from matplotlib import pyplot as plt
11+
12+
13+
X = [[0, 0], [1, 1]]
14+
y = [0, 1]
15+
clf = TemplateClassifier()
16+
clf.fit(X, y)
17+
18+
rng = np.random.RandomState(13)
19+
X_test = rng.rand(500, 2)
20+
y_pred = clf.predict(X_test)
21+
22+
X_0 = X_test[y_pred == 0]
23+
X_1 = X_test[y_pred == 1]
24+
25+
26+
p0 = plt.scatter(0, 0, c='red', s=100)
27+
p1 = plt.scatter(1, 1, c='blue', s=100)
28+
29+
ax0 = plt.scatter(X_0[:, 0], X_0[:, 1], c='crimson', s=50)
30+
ax1 = plt.scatter(X_1[:, 0], X_1[:, 1], c='deepskyblue', s=50)
31+
32+
leg = plt.legend([p0, p1, ax0, ax1],
33+
['Point 0', 'Point 1', 'Class 0', 'Class 1'],
34+
loc='upper left', fancybox=True, scatterpoints=1)
35+
leg.get_frame().set_alpha(0.5)
36+
37+
plt.xlabel('Feature 1')
38+
plt.ylabel('Feature 2')
39+
plt.xlim([-.5, 1.5])
40+
41+
plt.show()

examples/plot_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Plotting Template Estimator
44
===========================
55
6-
An example plot of TemplateEstimator
6+
An example plot of :class:`skltemplate.template.TemplateEstimator`
77
"""
88
import numpy as np
99
from skltemplate import TemplateEstimator
@@ -14,4 +14,4 @@
1414
estimator = TemplateEstimator()
1515
estimator.fit(X, y)
1616
plt.plot(estimator.predict(X))
17-
plt.show()
17+
plt.show()

examples/plot_transformer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
=============================
3+
Plotting Template Transformer
4+
=============================
5+
6+
An example plot of :class:`skltemplate.template.TemplateTransformer`
7+
"""
8+
import numpy as np
9+
from skltemplate import TemplateTransformer
10+
from matplotlib import pyplot as plt
11+
12+
X = np.arange(50, dtype=np.float).reshape(-1, 1)
13+
X /= 50
14+
estimator = TemplateTransformer()
15+
X_transformed = estimator.fit_transform(X)
16+
17+
plt.plot(X.flatten(), label='Original Data')
18+
plt.plot(X_transformed.flatten(), label='Transformed Data')
19+
plt.title('Plots of original and transformed data')
20+
21+
plt.legend(loc='best')
22+
plt.grid(True)
23+
plt.xlabel('Index')
24+
plt.ylabel('Value of Data')
25+
26+
plt.show()

skltemplate/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from .template import TemplateEstimator
1+
from .template import (TemplateEstimator, TemplateClassifier,
2+
TemplateTransformer)
23
from . import template
34

4-
__all__ = ['TemplateEstimator', 'template']
5+
__all__ = ['TemplateEstimator', 'TemplateClassifier',
6+
'TemplateTransformer', 'template']

skltemplate/template.py

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""
22
This is a module to be used as a reference for building other modules
33
"""
4-
from sklearn.base import BaseEstimator
5-
from sklearn.utils import check_X_y, check_array
4+
import numpy as np
5+
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
6+
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
7+
from sklearn.utils.multiclass import unique_labels
8+
from sklearn.metrics import euclidean_distances
9+
610

711
class TemplateEstimator(BaseEstimator):
812
""" A template estimator to be used as a reference implementation .
9-
13+
1014
Parameters
1115
----------
1216
demo_param : str, optional
@@ -21,7 +25,7 @@ def fit(self, X, y):
2125
Parameters
2226
----------
2327
X : array-like or sparse matrix of shape = [n_samples, n_features]
24-
The training input samples.
28+
The training input samples.
2529
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
2630
The target values (class labels in classification, real numbers in
2731
regression).
@@ -47,4 +51,135 @@ def predict(self, X):
4751
Returns :math:`x^2` where :math:`x` is the first column of `X`.
4852
"""
4953
X = check_array(X)
50-
return X[:, 0]**2
54+
return X[:, 0]**2
55+
56+
57+
class TemplateClassifier(BaseEstimator, ClassifierMixin):
58+
""" An example classifier which implements a 1-NN algorithm.
59+
60+
Parameters
61+
----------
62+
demo_param : str, optional
63+
A parameter used for demonstation of how to pass and store paramters.
64+
65+
Attributes
66+
----------
67+
X_ : array, shape = [n_samples, n_features]
68+
The input passed during :meth:`fit`
69+
y_ : array, shape = [n_samples]
70+
The labels passed during :meth:`fit`
71+
"""
72+
def __init__(self, demo_param='demo'):
73+
self.demo_param = demo_param
74+
75+
def fit(self, X, y):
76+
"""A reference implementation of a fitting function for a classifier.
77+
78+
Parameters
79+
----------
80+
X : array-like, shape = [n_samples, n_features]
81+
The training input samples.
82+
y : array-like, shape = [n_samples]
83+
The target values. An array of int.
84+
Returns
85+
-------
86+
self : object
87+
Returns self.
88+
"""
89+
# Check that X and y have correct shape
90+
X, y = check_X_y(X, y)
91+
# Store the classes seen durind fit
92+
self.classes_ = unique_labels(y)
93+
94+
self.X_ = X
95+
self.y_ = y
96+
# Return the classifier
97+
return self
98+
99+
def predict(self, X):
100+
""" A reference implementation of a prediction for a classifier.
101+
102+
Parameters
103+
----------
104+
X : array-like of shape = [n_samples, n_features]
105+
The input samples.
106+
107+
Returns
108+
-------
109+
y : array of int of shape = [n_samples]
110+
The label for each sample is the label of the closest sample
111+
seen udring fit.
112+
"""
113+
# Check is fit had been called
114+
check_is_fitted(self, ['X_', 'y_'])
115+
116+
# Input validation
117+
X = check_array(X)
118+
119+
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
120+
return self.y_[closest]
121+
122+
123+
class TemplateTransformer(BaseEstimator, TransformerMixin):
124+
""" An example transformer that returns the element-wise square root..
125+
126+
Parameters
127+
----------
128+
demo_param : str, optional
129+
A parameter used for demonstation of how to pass and store paramters.
130+
131+
Attributes
132+
----------
133+
input_shape : tuple
134+
The shape the data passed to :meth:`fit`
135+
"""
136+
def __init__(self, demo_param='demo'):
137+
self.demo_param = demo_param
138+
139+
def fit(self, X, y=None):
140+
"""A reference implementation of a fitting function for a transformer.
141+
142+
Parameters
143+
----------
144+
X : array-like or sparse matrix of shape = [n_samples, n_features]
145+
The training input samples.
146+
y : array-like, shape = [n_samples]
147+
The target values. An array of int.
148+
Returns
149+
-------
150+
self : object
151+
Returns self.
152+
"""
153+
X = check_array(X)
154+
155+
self.input_shape_ = X.shape
156+
157+
# Return the classifier
158+
return self
159+
160+
def transform(self, X):
161+
""" A reference implementation of a transform function.
162+
163+
Parameters
164+
----------
165+
X : array-like of shape = [n_samples, n_features]
166+
The input samples.
167+
168+
Returns
169+
-------
170+
X_transformed : array of int of shape = [n_samples, n_features]
171+
The array containing the element-wise square roots of the values
172+
in `X`
173+
"""
174+
# Check is fit had been called
175+
check_is_fitted(self, ['input_shape_'])
176+
177+
# Input validation
178+
X = check_array(X)
179+
180+
# Check that the input is of the same shape as the one passed
181+
# during fit.
182+
if X.shape != self.input_shape_:
183+
raise ValueError('Shape of input is different from what was seen'
184+
'in `fit`')
185+
return np.sqrt(X)

skltemplate/tests/test_common.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from sklearn.utils.estimator_checks import check_estimator
2-
from skltemplate import TemplateEstimator
2+
from skltemplate import (TemplateEstimator, TemplateClassifier,
3+
TemplateTransformer)
34

4-
def test_common():
5+
6+
def test_estimator():
57
return check_estimator(TemplateEstimator)
8+
9+
10+
def test_classifier():
11+
return check_estimator(TemplateClassifier)
12+
13+
14+
def test_transformer():
15+
return check_estimator(TemplateTransformer)

0 commit comments

Comments
 (0)