Skip to content

Commit b6dffb5

Browse files
committed
Make Gaussian/Mahalanobis classifiers work with non-ndarray images (fixes #49).
1 parent 0871018 commit b6dffb5

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

spectral/algorithms/classifiers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def classify_image(self, image):
193193
'''
194194
import math
195195
import spectral
196-
if not self.cache_class_scores:
196+
if not (self.cache_class_scores and isinstance(image, np.ndarray)):
197197
return super(GaussianClassifier, self).classify_image(image)
198198

199199
status = spectral._status
@@ -289,7 +289,7 @@ def classify_image(self, image):
289289
'''
290290
import spectral
291291
from .detectors import RX
292-
if not self.cache_class_scores:
292+
if not (self.cache_class_scores and isinstance(image, np.ndarray)):
293293
return super(MahalanobisDistanceClassifier,
294294
self).classify_image(image)
295295

spectral/tests/classifiers.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,31 @@ def test_gmlc_spectrum_image_equal(self):
8383
assert(gmlc.classify_spectrum(data[2, 2]) == \
8484
gmlc.classify_image(data)[2, 2])
8585

86+
def test_gmlc_classify_spyfile_runs(self):
87+
'''Tests that GaussianClassifier classifies a SpyFile object.'''
88+
gmlc = spy.GaussianClassifier(self.ts, min_samples=600)
89+
ret = gmlc.classify_image(self.image)
90+
91+
def test_gmlc_classify_transformedimage_runs(self):
92+
'''Tests that GaussianClassifier classifies a TransformedImage object.'''
93+
pc = spy.principal_components(self.data).reduce(num=3)
94+
ximg = pc.transform(self.image)
95+
ts = spy.create_training_classes(pc.transform(self.data), self.gt,
96+
calc_stats=True)
97+
gmlc = spy.GaussianClassifier(ts)
98+
ret = gmlc.classify_image(ximg)
99+
100+
def test_gmlc_classify_ndarray_transformedimage_equal(self):
101+
'''Gaussian classification of an ndarray and TransformedImage are equal'''
102+
pc = spy.principal_components(self.data).reduce(num=3)
103+
ximg = pc.transform(self.image)
104+
ts = spy.create_training_classes(pc.transform(self.data), self.gt,
105+
calc_stats=True)
106+
gmlc = spy.GaussianClassifier(ts)
107+
cl_ximg = gmlc.classify_image(ximg)
108+
cl_ndarray = gmlc.classify_image(pc.transform(self.data))
109+
assert(np.all(cl_ximg == cl_ndarray))
110+
86111
def test_mahalanobis_class_mean(self):
87112
'''Test that a class's mean spectrum is classified as that class.
88113
Note this assumes that class priors are equal.
@@ -91,6 +116,31 @@ def test_mahalanobis_class_mean(self):
91116
cl = mdc.classes[0]
92117
assert(mdc.classify(cl.stats.mean) == cl.index)
93118

119+
def test_mahalanobis_classify_spyfile_runs(self):
120+
'''Mahalanobis classifier works with a SpyFile object.'''
121+
mdc = spy.MahalanobisDistanceClassifier(self.ts)
122+
ret = mdc.classify_image(self.image)
123+
124+
def test_mahalanobis_classify_transformedimage_runs(self):
125+
'''Mahalanobis classifier works with a TransformedImage object.'''
126+
pc = spy.principal_components(self.data).reduce(num=3)
127+
ximg = pc.transform(self.image)
128+
ts = spy.create_training_classes(pc.transform(self.data), self.gt,
129+
calc_stats=True)
130+
gmlc = spy.MahalanobisDistanceClassifier(ts)
131+
ret = gmlc.classify_image(ximg)
132+
133+
def test_mahalanobis_classify_ndarray_transformedimage_equal(self):
134+
'''Mahalanobis classification of ndarray and TransformedImage are equal'''
135+
pc = spy.principal_components(self.data).reduce(num=3)
136+
ximg = pc.transform(self.image)
137+
ts = spy.create_training_classes(pc.transform(self.data), self.gt,
138+
calc_stats=True)
139+
mdc = spy.GaussianClassifier(ts)
140+
cl_ximg = mdc.classify_image(ximg)
141+
cl_ndarray = mdc.classify_image(pc.transform(self.data))
142+
assert(np.all(cl_ximg == cl_ndarray))
143+
94144
def test_perceptron_learns_and(self):
95145
'''Test that 2x1 network can learn the logical AND function.'''
96146
from spectral.algorithms.perceptron import test_and

0 commit comments

Comments
 (0)