Skip to content

Commit 0882eda

Browse files
committed
Fix general bug in GreedySelector that would pick the same point if there are degeneracies
1 parent 2469350 commit 0882eda

File tree

7 files changed

+28
-15
lines changed

7 files changed

+28
-15
lines changed

src/skmatter/_selection.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -427,10 +427,8 @@ def _get_best_new_selection(self, scorer, X, y):
427427
scores = scorer(X, y)
428428

429429
# Get the score argmax, but only for idxs not already selected
430-
_tmp_scores = {
431-
i: score for i, score in enumerate(scores) if i not in self.selected_idx_
432-
}
433-
max_score_idx = max(_tmp_scores, key=_tmp_scores.get)
430+
scores[self.selected_idx_[: self.n_selected_]] = -np.inf
431+
max_score_idx = np.argmax(scores)
434432
if self.score_threshold is not None:
435433
if self.first_score_ is None:
436434
self.first_score_ = scores[max_score_idx]

tests/test_feature_pcov_cur.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@ def test_non_it(self):
3333
self.idx = [2, 8, 3, 6, 7, 9, 1, 0, 5]
3434
selector = PCovCUR(n_to_select=9, recompute_every=0)
3535
selector.fit(self.X, self.y)
36-
3736
self.assertTrue(np.allclose(selector.selected_idx_, self.idx))
3837

3938

tests/test_feature_simple_cur.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,8 @@ def test_unique_selected_idx_zero_score(self):
5050
n_samples = 10
5151
n_features = 15
5252
X = np.random.rand(n_samples, n_features)
53-
X[:, 3] = np.random.rand(10) * 1e-13
54-
X[:, 4] = np.random.rand(10) * 1e-13
53+
X[:, 1] = X[:, 0]
54+
X[:, 2] = X[:, 0]
5555
selector_problem = CUR(n_to_select=len(X.T)).fit(X)
5656
assert len(selector_problem.selected_idx_) == len(
5757
set(selector_problem.selected_idx_)

tests/test_feature_simple_fps.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,8 @@ def test_unique_selected_idx_zero_score(self):
9393
n_samples = 10
9494
n_features = 15
9595
X = np.random.rand(n_samples, n_features)
96-
X[:, 3] = np.random.rand(10) * 1e-13
97-
X[:, 4] = np.random.rand(10) * 1e-13
96+
X[:, 1] = X[:, 0]
97+
X[:, 2] = X[:, 0]
9898
selector_problem = FPS(n_to_select=len(X.T)).fit(X)
9999
assert len(selector_problem.selected_idx_) == len(
100100
set(selector_problem.selected_idx_)

tests/test_sample_simple_cur.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -61,9 +61,9 @@ def test_unique_selected_idx_zero_score(self):
6161
n_samples = 10
6262
n_features = 15
6363
X = np.random.rand(n_samples, n_features)
64-
X[4, :] = np.random.rand(15) * 1e-13
65-
X[5, :] = np.random.rand(15) * 1e-13
66-
X[6, :] = np.random.rand(15) * 1e-13
64+
X[1] = X[0]
65+
X[2] = X[0]
66+
X[3] = X[0]
6767
selector_problem = CUR(n_to_select=len(X)).fit(X)
6868
assert len(selector_problem.selected_idx_) == len(
6969
set(selector_problem.selected_idx_)

tests/test_sample_simple_fps.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -108,9 +108,9 @@ def test_unique_selected_idx_zero_score(self):
108108
n_samples = 10
109109
n_features = 15
110110
X = np.random.rand(n_samples, n_features)
111-
X[4, :] = np.random.rand(15) * 1e-13
112-
X[5, :] = np.random.rand(15) * 1e-13
113-
X[6, :] = np.random.rand(15) * 1e-13
111+
X[1] = X[0]
112+
X[2] = X[0]
113+
X[3] = X[0]
114114
selector_problem = FPS(n_to_select=len(X)).fit(X)
115115
assert len(selector_problem.selected_idx_) == len(
116116
set(selector_problem.selected_idx_)

tests/test_voronoi_fps.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,22 @@ def test_score(self):
165165
)
166166
)
167167

168+
def test_unique_selected_idx_zero_score(self):
169+
"""
170+
Tests that the selected idxs are unique, which may not be the
171+
case when the score is numerically zero
172+
"""
173+
np.random.seed(0)
174+
n_samples = 10
175+
n_features = 15
176+
X = np.random.rand(n_samples, n_features)
177+
X[1] = X[0]
178+
X[2] = X[0]
179+
selector_problem = VoronoiFPS(n_to_select=n_samples, initialize=3).fit(X)
180+
assert len(selector_problem.selected_idx_) == len(
181+
set(selector_problem.selected_idx_)
182+
)
183+
168184

169185
if __name__ == "__main__":
170186
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)