Skip to content

Commit 5fb88f1

Browse files
committed
Fix the tests that checks for infinities
1 parent d802221 commit 5fb88f1

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

tests/test_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,23 +84,25 @@ def test_empty_matrix(self):
8484
assert nbr_distances.size == 0
8585

8686
def test_infinite_values(self):
87-
"""Test with a matrix containing infinite values."""
87+
"""Test that infinite values are preserved in the neighbor matrix."""
8888
# Create a distance matrix with some infinite values
8989
distances_data = np.array(
9090
[[0.0, 1.0, np.inf, 3.0], [1.0, 0.0, 4.0, np.inf], [np.inf, 4.0, 0.0, 6.0], [3.0, np.inf, 6.0, 0.0]]
9191
)
9292
distances = csr_matrix(distances_data)
9393

94-
# Extract neighbors
95-
indices, nbr_distances = extract_neighbors_from_distances(distances)
94+
# Extract neighbors with include_self=True
95+
indices, nbr_distances = extract_neighbors_from_distances(distances, include_self=True)
9696

97-
# Check shapes - should exclude infinite values
98-
assert indices.shape == (4, 3) # One less neighbor per cell
99-
assert nbr_distances.shape == (4, 3)
97+
# Check for a specific cell with infinite distance
98+
row0_neighbors = indices[0]
99+
assert 2 in row0_neighbors, "Neighbor with infinite distance should be included"
100100

101-
# Check that infinite values are excluded
102-
for i in range(4):
103-
assert np.all(np.isfinite(nbr_distances[i]))
101+
# Find where the infinite value is in the results
102+
idx = np.where(row0_neighbors == 2)[0][0]
103+
104+
# Verify the distance is infinite
105+
assert np.isinf(nbr_distances[0, idx]), "Distance should be infinite"
104106

105107
def test_include_self_parameter(self):
106108
"""Test the include_self parameter to control self-connections."""

0 commit comments

Comments
 (0)