|
7 | 7 | from skmatter.datasets import load_degenerate_CH4_manifold |
8 | 8 | from skmatter.metrics import ( |
9 | 9 | check_global_reconstruction_measures_input, |
| 10 | + check_local_reconstruction_measures_input, |
10 | 11 | componentwise_prediction_rigidity, |
11 | 12 | global_reconstruction_distortion, |
12 | 13 | global_reconstruction_error, |
@@ -233,6 +234,28 @@ def test_source_target_len(self): |
233 | 234 | expected_message = "First dimension of X (2) and Y (1) must match" |
234 | 235 | self.assertEqual(str(context.exception), expected_message) |
235 | 236 |
|
| 237 | + def test_len_n_local_points(self): |
| 238 | + # tests that source len is greater or equal than n_local_points in LFRE |
| 239 | + X = np.array([[1, 2, 3], [4, 5, 6]]) |
| 240 | + Y = np.array([[1, 1, 1], [2, 2, 2]]) |
| 241 | + |
| 242 | + n_local_points = 10 |
| 243 | + train_idx = [0] |
| 244 | + test_idx = [1] |
| 245 | + scaler = None |
| 246 | + estimator = None |
| 247 | + |
| 248 | + with self.assertRaises(ValueError) as context: |
| 249 | + check_local_reconstruction_measures_input( |
| 250 | + X, Y, n_local_points, train_idx, test_idx, scaler, estimator |
| 251 | + ) |
| 252 | + |
| 253 | + expected_message = ( |
| 254 | + f"X has {len(X)} samples but n_local_points={n_local_points}. " |
| 255 | + "Must have at least n_local_points samples" |
| 256 | + ) |
| 257 | + self.assertEqual(str(context.exception), expected_message) |
| 258 | + |
236 | 259 |
|
237 | 260 | class DistanceTests(unittest.TestCase): |
238 | 261 | @classmethod |
|
0 commit comments