Skip to content

Commit d802221

Browse files
committed
Fix tests failing due to include_self
1 parent aee1ca1 commit d802221

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

tests/test_cellmapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def test_load_squidpy_distances(self, adata_spatial, squidpy_params):
417417
assert "leiden_conf" in cm.query.obs
418418

419419
@pytest.mark.parametrize("include_self", [True, False])
420-
def test_load_distances_with_include_self(self, adata_spatial):
420+
def test_load_distances_with_include_self(self, adata_spatial, include_self):
421421
"""Test loading precomputed distances with and without self-connections."""
422422

423423
# Compute neighbors with scanpy

tests/test_utils.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,41 @@
88
class TestExtractNeighborsFromDistances:
99
"""Tests for the extract_neighbors_from_distances function."""
1010

11-
def test_basic_extraction(self):
11+
@pytest.mark.parametrize("include_self", [None, True, False])
12+
def test_basic_extraction(self, include_self):
1213
"""Test basic extraction from a simple distance matrix."""
1314
# Create a simple distance matrix
1415
distances_data = np.array(
1516
[[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 4.0, 5.0], [2.0, 4.0, 0.0, 6.0], [3.0, 5.0, 6.0, 0.0]]
1617
)
1718
distances = csr_matrix(distances_data)
1819

19-
# Extract neighbors
20-
indices, nbr_distances = extract_neighbors_from_distances(distances)
20+
# Extract neighbors with specified include_self parameter
21+
indices, nbr_distances = extract_neighbors_from_distances(distances, include_self=include_self)
22+
23+
# Determine expected shapes and content based on include_self
24+
if include_self is False or include_self is None:
25+
# Without self, each row has 3 neighbors
26+
expected_shape = (4, 3)
27+
# First neighbors should be the closest non-self neighbors
28+
expected_first_neighbors = [1, 0, 0, 0]
29+
else:
30+
# With self (or default behavior), each row has 4 neighbors
31+
expected_shape = (4, 4)
32+
# First neighbors should be self (distance 0)
33+
expected_first_neighbors = [0, 1, 2, 3]
2134

2235
# Check shapes
23-
assert indices.shape == (4, 4)
24-
assert nbr_distances.shape == (4, 4)
25-
26-
# Check values - neighbors should be sorted by distance
27-
np.testing.assert_array_equal(indices[0], [0, 1, 2, 3])
28-
np.testing.assert_array_equal(indices[1], [1, 0, 2, 3])
29-
np.testing.assert_array_almost_equal(nbr_distances[0], [0.0, 1.0, 2.0, 3.0])
30-
np.testing.assert_array_almost_equal(nbr_distances[1], [0.0, 1.0, 4.0, 5.0])
36+
assert indices.shape == expected_shape
37+
assert nbr_distances.shape == expected_shape
38+
39+
# Check first neighbors for each cell
40+
for i, expected in enumerate(expected_first_neighbors):
41+
assert indices[i, 0] == expected, f"Cell {i}'s first neighbor should be {expected}"
42+
43+
# Check all distances are sorted
44+
for i in range(4):
45+
assert np.all(np.diff(nbr_distances[i, :]) >= 0), f"Distances for cell {i} should be sorted"
3146

3247
def test_sparse_matrix(self):
3348
"""Test extraction from a sparse distance matrix."""

0 commit comments

Comments
 (0)