Skip to content

Commit dd31493

Browse files
cajchristian and Christian Jorgensen authored
Allowing FPS to take numpy array of ints as initialize parameter (#225)
* Add numpy array support for initialize parameter for FPS * Adding unit test for initialize as np array * Fixed linting issue * Added fix for np array value error * Adding unit test for case with np array containing non-ints * Adding documentation in skmatter.sample_selection * Removed unnecessary test and fixed initialize * Revert "Removed unnecessary test and fixed initialize" This reverts commit c25c850. * Adding "numpy" before ndarray in docstrings * Changing error message and adding another unit test * Added unit tests * Combined if statements for list and array * Update CHANGELOG --------- Co-authored-by: Christian Jorgensen <[email protected]>
1 parent bd54517 commit dd31493

File tree

6 files changed

+64
-10
lines changed

6 files changed

+64
-10
lines changed

CHANGELOG

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The rules for CHANGELOG file:
1313

1414
0.3.0 (XXXX/XX/XX)
1515
------------------
16+
- Updating ``FPS`` to allow a numpy array of ints as an initialize parameter (#145)
1617
- Supported Python versions are now ranging from 3.9 - 3.12.
1718

1819
0.2.0 (2023/08/24)

src/skmatter/_selection.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ class _FPS(GreedySelector):
934934
Parameters
935935
----------
936936
937-
initialize: int, list of int, or 'random', default=0
937+
initialize: int, list of int, numpy.ndarray of int, or 'random', default=0
938938
Index of the first selection(s). If 'random', picks a random
939939
value when fit starts. Stored in :py:attr:`self.initialize`.
940940
@@ -1038,7 +1038,14 @@ def _init_greedy_search(self, X, y, n_to_select):
10381038
self.hausdorff_ = np.full(X.shape[self._axis], np.inf)
10391039
self.hausdorff_at_select_ = np.full(X.shape[self._axis], np.inf)
10401040

1041-
if self.initialize == "random":
1041+
if isinstance(self.initialize, (np.ndarray, list)):
1042+
if all(isinstance(i, numbers.Integral) for i in self.initialize):
1043+
for i, val in enumerate(self.initialize):
1044+
self.selected_idx_[i] = val
1045+
self._update_post_selection(X, y, self.selected_idx_[i])
1046+
else:
1047+
raise ValueError("Invalid value of the initialize parameter")
1048+
elif self.initialize == "random":
10421049
random_state = check_random_state(self.random_state)
10431050
initialize = random_state.randint(X.shape[self._axis])
10441051
self.selected_idx_[0] = initialize
@@ -1047,12 +1054,6 @@ def _init_greedy_search(self, X, y, n_to_select):
10471054
initialize = self.initialize
10481055
self.selected_idx_[0] = initialize
10491056
self._update_post_selection(X, y, self.selected_idx_[0])
1050-
elif isinstance(self.initialize, list) and all(
1051-
[isinstance(i, numbers.Integral) for i in self.initialize]
1052-
):
1053-
for i, val in enumerate(self.initialize):
1054-
self.selected_idx_[i] = val
1055-
self._update_post_selection(X, y, self.selected_idx_[i])
10561057
else:
10571058
raise ValueError("Invalid value of the initialize parameter")
10581059

src/skmatter/feature_selection/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class FPS(_FPS):
1212
Parameters
1313
----------
1414
15-
initialize: int, list of int, or 'random', default=0
15+
initialize: int, list of int, numpy.ndarray of int, or 'random', default=0
1616
Index of the first selection(s). If 'random', picks a random
1717
value when fit starts. Stored in :py:attr:`self.initialize`.
1818

src/skmatter/sample_selection/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class FPS(_FPS):
5858
Parameters
5959
----------
6060
61-
initialize: int, list of int, or 'random', default=0
61+
initialize: int, list of int, numpy.ndarray of int, or 'random', default=0
6262
Index of the first selection(s). If 'random', picks a random
6363
value when fit starts. Stored in :py:attr:`self.initialize`.
6464

tests/test_feature_simple_fps.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
import numpy as np
34
from sklearn.datasets import load_diabetes as get_dataset
45
from sklearn.utils.validation import NotFittedError
56

@@ -42,6 +43,31 @@ def test_initialize(self):
4243
for i in range(4):
4344
self.assertEqual(selector.selected_idx_[i], self.idx[i])
4445

46+
initialize = np.array(self.idx[:4])
47+
with self.subTest(initialize=initialize):
48+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
49+
selector.fit(self.X)
50+
for i in range(4):
51+
self.assertEqual(selector.selected_idx_[i], self.idx[i])
52+
53+
initialize = np.array([1, 5, 3, 0.25])
54+
with self.subTest(initialize=initialize):
55+
with self.assertRaises(ValueError) as cm:
56+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
57+
selector.fit(self.X)
58+
self.assertEqual(
59+
str(cm.exception), "Invalid value of the initialize parameter"
60+
)
61+
62+
initialize = np.array([[1, 5, 3], [2, 4, 6]])
63+
with self.subTest(initialize=initialize):
64+
with self.assertRaises(ValueError) as cm:
65+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
66+
selector.fit(self.X)
67+
self.assertEqual(
68+
str(cm.exception), "Invalid value of the initialize parameter"
69+
)
70+
4571
with self.assertRaises(ValueError) as cm:
4672
selector = FPS(n_to_select=1, initialize="bad")
4773
selector.fit(self.X)

tests/test_sample_simple_fps.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
import numpy as np
34
from sklearn.datasets import load_diabetes as get_dataset
45
from sklearn.utils.validation import NotFittedError
56

@@ -43,6 +44,31 @@ def test_initialize(self):
4344
for i in range(4):
4445
self.assertEqual(selector.selected_idx_[i], self.idx[i])
4546

47+
initialize = np.array(self.idx[:4])
48+
with self.subTest(initialize=initialize):
49+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
50+
selector.fit(self.X)
51+
for i in range(4):
52+
self.assertEqual(selector.selected_idx_[i], self.idx[i])
53+
54+
initialize = np.array([1, 5, 3, 0.25])
55+
with self.subTest(initialize=initialize):
56+
with self.assertRaises(ValueError) as cm:
57+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
58+
selector.fit(self.X)
59+
self.assertEqual(
60+
str(cm.exception), "Invalid value of the initialize parameter"
61+
)
62+
63+
initialize = np.array([[1, 5, 3], [2, 4, 6]])
64+
with self.subTest(initialize=initialize):
65+
with self.assertRaises(ValueError) as cm:
66+
selector = FPS(n_to_select=len(self.idx) - 1, initialize=initialize)
67+
selector.fit(self.X)
68+
self.assertEqual(
69+
str(cm.exception), "Invalid value of the initialize parameter"
70+
)
71+
4672
with self.assertRaises(ValueError) as cm:
4773
selector = FPS(n_to_select=1, initialize="bad")
4874
selector.fit(self.X)

0 commit comments

Comments
 (0)