Skip to content

Commit ac6e63b

Browse files
Added TemplateClassifer
1 parent b00ff08 commit ac6e63b

File tree

4 files changed

+107
-9
lines changed

4 files changed

+107
-9
lines changed

examples/plot_classifier.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
============================
3+
Plotting Template Classifier
4+
============================
5+
6+
An example plot of TemplateClassifier
7+
"""
8+
import numpy as np
9+
from skltemplate import TemplateClassifier
10+
from matplotlib import pyplot as plt
11+
12+
13+
X = [[0, 0], [1, 1]]
14+
y = [0, 1]
15+
clf = TemplateClassifier()
16+
clf.fit(X, y)
17+
18+
rng = np.random.RandomState(13)
19+
X_test = rng.rand(500, 2)
20+
y_pred = clf.predict(X_test)
21+
22+
X_0 = X_test[y_pred == 0]
23+
X_1 = X_test[y_pred == 1]
24+
25+
ax0 = plt.scatter(X_0[:, 0], X_0[:, 1], c='darkorange', s=30)
26+
ax1 = plt.scatter(X_1[:, 0], X_1[:, 1], c='deepskyblue', s=30)
27+
28+
29+
plt.legend([ax0, ax1], ['Class 0', 'Class 1'])
30+
plt.show()

skltemplate/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .template import TemplateEstimator
1+
from .template import TemplateEstimator, TemplateClassifier
22
from . import template
33

4-
__all__ = ['TemplateEstimator', 'template']
4+
__all__ = ['TemplateEstimator', 'TemplateClassifier', 'template']

skltemplate/template.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""
22
This is a module to be used as a reference for building other modules
33
"""
4-
from sklearn.base import BaseEstimator
5-
from sklearn.utils import check_X_y, check_array
4+
import numpy as np
5+
from sklearn.base import BaseEstimator, ClassifierMixin
6+
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
7+
from sklearn.utils.multiclass import unique_labels
8+
from sklearn.metrics import euclidean_distances
9+
610

711
class TemplateEstimator(BaseEstimator):
812
""" A template estimator to be used as a reference implementation .
9-
13+
1014
Parameters
1115
----------
1216
demo_param : str, optional
@@ -21,7 +25,7 @@ def fit(self, X, y):
2125
Parameters
2226
----------
2327
X : array-like or sparse matrix of shape = [n_samples, n_features]
24-
The training input samples.
28+
The training input samples.
2529
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
2630
The target values (class labels in classification, real numbers in
2731
regression).
@@ -47,4 +51,63 @@ def predict(self, X):
4751
Returns :math:`x^2` where :math:`x` is the first column of `X`.
4852
"""
4953
X = check_array(X)
50-
return X[:, 0]**2
54+
return X[:, 0]**2
55+
56+
57+
class TemplateClassifier(BaseEstimator, ClassifierMixin):
58+
""" An example for a classifier which implements a 1-NN algorithm.
59+
60+
Parameters
61+
----------
62+
demo_param : str, optional
63+
A parameter used for demonstation of how to pass and store paramters.
64+
"""
65+
def __init__(self, demo_param='demo'):
66+
self.demo_param = demo_param
67+
68+
def fit(self, X, y):
69+
"""A reference implementation of a fitting function for a classifier.
70+
71+
Parameters
72+
----------
73+
X : array-like or sparse matrix of shape = [n_samples, n_features]
74+
The training input samples.
75+
y : array-like, shape = [n_samples]
76+
The target values. An array of int.
77+
Returns
78+
-------
79+
self : object
80+
Returns self.
81+
"""
82+
# Check that X and y have correct shape
83+
X, y = check_X_y(X, y)
84+
# Store the classes seen durind fit
85+
self.classes_ = unique_labels(y)
86+
87+
self.X_ = X
88+
self.y_ = y
89+
# Return the classifier
90+
return self
91+
92+
def predict(self, X):
93+
""" A reference implementation of a prediction for a classifier.
94+
95+
Parameters
96+
----------
97+
X : array-like of shape = [n_samples, n_features]
98+
The input samples.
99+
100+
Returns
101+
-------
102+
y : array of int of shape = [n_samples]
103+
The label for each sample is the label of the closest sample
104+
seen udring fit.
105+
"""
106+
# Check is fit had been called
107+
check_is_fitted(self, ['X_', 'y_'])
108+
109+
# Input validation
110+
X = check_array(X)
111+
112+
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
113+
return self.y_[closest]

skltemplate/tests/test_common.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from sklearn.utils.estimator_checks import check_estimator
2-
from skltemplate import TemplateEstimator
2+
from skltemplate import TemplateEstimator, TemplateClassifier
33

4-
def test_common():
4+
5+
def test_estimator():
56
return check_estimator(TemplateEstimator)
7+
8+
9+
def test_classifier():
10+
return check_estimator(TemplateClassifier)

0 commit comments

Comments
 (0)