Skip to content

Commit a1a6e34

Browse files
committed
Check that image/mask shapes align and simplify iterator logic.
1 parent 5055391 commit a1a6e34

File tree

2 files changed

+20
-30
lines changed

2 files changed

+20
-30
lines changed

spectral/algorithms/algorithms.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,39 +87,29 @@ class ImageMaskIterator(Iterator):
8787
'''
8888
An iterator over all pixels in an image corresponding to a specified mask.
8989
'''
90-
def __init__(self, im, mask, index=None):
91-
self.image = im
90+
def __init__(self, image, mask, index=None):
91+
if mask.shape != image.shape[:len(mask.shape)]:
92+
raise ValueError('Mask shape does not match image.')
93+
self.image = image
9294
self.index = index
9395
# Get the proper mask for the training set
9496
if index:
9597
self.mask = np.equal(mask, index)
9698
else:
9799
self.mask = np.not_equal(mask, 0)
98-
self.numElements = sum(self.mask.ravel())
100+
self.n_elements = sum(self.mask.ravel())
99101

100102
def get_num_elements(self):
101-
return self.numElements
103+
return self.n_elements
102104

103105
def get_num_bands(self):
104106
return self.image.shape[2]
105107

106108
def __iter__(self):
107-
from numpy import transpose, indices, reshape, compress, not_equal
108-
(nrows, ncols, nbands) = self.image.shape
109-
110-
# Translate the mask into indices into the data source
111-
inds = transpose(indices((nrows, ncols)), (1, 2, 0))
112-
inds = reshape(inds, (nrows * ncols, 2))
113-
inds = compress(not_equal(self.mask.ravel(), 0), inds, 0).astype('h')
114-
115-
for i in range(inds.shape[0]):
116-
sample = self.image[inds[i][0], inds[i][1]].astype(
117-
self.image.dtype)
118-
if len(sample.shape) == 3:
119-
sample.shape = (sample.shape[2],)
120-
(self.row, self.col) = inds[i][:2]
121-
yield sample
122-
109+
coords = np.argwhere(self.mask)
110+
for (i, j) in coords:
111+
(self.row, self.col) = (i, j)
112+
yield self.image[i, j].astype(self.image.dtype).squeeze()
123113

124114
def iterator(image, mask=None, index=None):
125115
'''

spectral/tests/iterators.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,42 +64,42 @@ def test_iterator_nonzero(self):
6464
classes = self.gt.ravel()
6565
pixels = data.reshape((-1, data.shape[-1]))
6666
sum = np.sum(pixels[classes > 0], 0)
67-
itsum = np.sum(np.array([x for x in iterator(data, classes)]), 0)
67+
itsum = np.sum(np.array([x for x in iterator(data, self.gt)]), 0)
6868
assert_allclose(sum, itsum)
6969

7070
def test_iterator_index(self):
7171
'''Iteration over single ground truth index'''
7272
from spectral.algorithms.algorithms import iterator
73-
i = 5
73+
cls = 5
7474
data = self.image.load()
7575
classes = self.gt.ravel()
7676
pixels = data.reshape((-1, data.shape[-1]))
77-
sum = np.sum(pixels[classes == 5], 0)
78-
itsum = np.sum(np.array([x for x in iterator(data, classes, 5)]), 0)
77+
sum = np.sum(pixels[classes == cls], 0)
78+
itsum = np.sum(np.array([x for x in iterator(data, self.gt, cls)]), 0)
7979
assert_allclose(sum, itsum)
8080

8181
def test_iterator_spyfile(self):
8282
'''Iteration over SpyFile object for single ground truth index'''
8383
from spectral.algorithms.algorithms import iterator
84-
i = 5
84+
cls = 5
8585
data = self.image.load()
8686
classes = self.gt.ravel()
8787
pixels = data.reshape((-1, data.shape[-1]))
88-
sum = np.sum(pixels[classes == 5], 0)
89-
itsum = np.sum(np.array([x for x in iterator(self.image, classes, 5)]),
88+
sum = np.sum(pixels[classes == cls], 0)
89+
itsum = np.sum(np.array([x for x in iterator(self.image, self.gt, cls)]),
9090
0)
9191
assert_allclose(sum, itsum)
9292

9393
def test_iterator_spyfile_nomemmap(self):
9494
'''Iteration over SpyFile object without memmap'''
9595
from spectral.algorithms.algorithms import iterator
96-
i = 5
96+
cls = 5
9797
data = self.image.load()
9898
classes = self.gt.ravel()
9999
pixels = data.reshape((-1, data.shape[-1]))
100-
sum = np.sum(pixels[classes == 5], 0)
100+
sum = np.sum(pixels[classes == cls], 0)
101101
image = spy.open_image('92AV3C.lan')
102-
itsum = np.sum(np.array([x for x in iterator(image, classes, 5)]), 0)
102+
itsum = np.sum(np.array([x for x in iterator(image, self.gt, cls)]), 0)
103103
assert_allclose(sum, itsum)
104104

105105

0 commit comments

Comments
 (0)