Skip to content

Commit 108ce37

Browse files
sofiia-chornaPicoCentauri
authored andcommitted
Added test to check len(samples) and n_local_points error
1 parent b215be3 commit 108ce37

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_metrics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from skmatter.datasets import load_degenerate_CH4_manifold
88
from skmatter.metrics import (
99
check_global_reconstruction_measures_input,
10+
check_local_reconstruction_measures_input,
1011
componentwise_prediction_rigidity,
1112
global_reconstruction_distortion,
1213
global_reconstruction_error,
@@ -233,6 +234,28 @@ def test_source_target_len(self):
233234
expected_message = "First dimension of X (2) and Y (1) must match"
234235
self.assertEqual(str(context.exception), expected_message)
235236

237+
def test_len_n_local_points(self):
238+
# tests that source len is greater or equal than n_local_points in LFRE
239+
X = np.array([[1, 2, 3], [4, 5, 6]])
240+
Y = np.array([[1, 1, 1], [2, 2, 2]])
241+
242+
n_local_points = 10
243+
train_idx = [0]
244+
test_idx = [1]
245+
scaler = None
246+
estimator = None
247+
248+
with self.assertRaises(ValueError) as context:
249+
check_local_reconstruction_measures_input(
250+
X, Y, n_local_points, train_idx, test_idx, scaler, estimator
251+
)
252+
253+
expected_message = (
254+
f"X has {len(X)} samples but n_local_points={n_local_points}. "
255+
"Must have at least n_local_points samples"
256+
)
257+
self.assertEqual(str(context.exception), expected_message)
258+
236259

237260
class DistanceTests(unittest.TestCase):
238261
@classmethod

0 commit comments

Comments
 (0)