|
8 | 8 | class TestExtractNeighborsFromDistances: |
9 | 9 | """Tests for the extract_neighbors_from_distances function.""" |
10 | 10 |
|
11 | | - def test_basic_extraction(self): |
| 11 | + @pytest.mark.parametrize("include_self", [None, True, False]) |
| 12 | + def test_basic_extraction(self, include_self): |
12 | 13 | """Test basic extraction from a simple distance matrix.""" |
13 | 14 | # Create a simple distance matrix |
14 | 15 | distances_data = np.array( |
15 | 16 | [[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 4.0, 5.0], [2.0, 4.0, 0.0, 6.0], [3.0, 5.0, 6.0, 0.0]] |
16 | 17 | ) |
17 | 18 | distances = csr_matrix(distances_data) |
18 | 19 |
|
19 | | - # Extract neighbors |
20 | | - indices, nbr_distances = extract_neighbors_from_distances(distances) |
| 20 | + # Extract neighbors with specified include_self parameter |
| 21 | + indices, nbr_distances = extract_neighbors_from_distances(distances, include_self=include_self) |
| 22 | + |
| 23 | + # Determine expected shapes and content based on include_self |
| 24 | + if include_self is False or include_self is None: |
| 25 | + # Without self, each row has 3 neighbors |
| 26 | + expected_shape = (4, 3) |
| 27 | + # First neighbors should be the closest non-self neighbors |
| 28 | + expected_first_neighbors = [1, 0, 0, 0] |
| 29 | + else: |
| 30 | + # With self (or default behavior), each row has 4 neighbors |
| 31 | + expected_shape = (4, 4) |
| 32 | + # First neighbors should be self (distance 0) |
| 33 | + expected_first_neighbors = [0, 1, 2, 3] |
21 | 34 |
|
22 | 35 | # Check shapes |
23 | | - assert indices.shape == (4, 4) |
24 | | - assert nbr_distances.shape == (4, 4) |
25 | | - |
26 | | - # Check values - neighbors should be sorted by distance |
27 | | - np.testing.assert_array_equal(indices[0], [0, 1, 2, 3]) |
28 | | - np.testing.assert_array_equal(indices[1], [1, 0, 2, 3]) |
29 | | - np.testing.assert_array_almost_equal(nbr_distances[0], [0.0, 1.0, 2.0, 3.0]) |
30 | | - np.testing.assert_array_almost_equal(nbr_distances[1], [0.0, 1.0, 4.0, 5.0]) |
| 36 | + assert indices.shape == expected_shape |
| 37 | + assert nbr_distances.shape == expected_shape |
| 38 | + |
| 39 | + # Check first neighbors for each cell |
| 40 | + for i, expected in enumerate(expected_first_neighbors): |
| 41 | + assert indices[i, 0] == expected, f"Cell {i}'s first neighbor should be {expected}" |
| 42 | + |
| 43 | + # Check all distances are sorted |
| 44 | + for i in range(4): |
| 45 | + assert np.all(np.diff(nbr_distances[i, :]) >= 0), f"Distances for cell {i} should be sorted" |
31 | 46 |
|
32 | 47 | def test_sparse_matrix(self): |
33 | 48 | """Test extraction from a sparse distance matrix.""" |
|
0 commit comments