Skip to content

Commit 50705b3

Browse files
committed
test added
1 parent d285203 commit 50705b3

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/scanpy/neighbors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def neighbors_from_distance(
337337
dists_key = "distances" if key_added is None else key_added + "_distances"
338338
conns_key = "connectivities" if key_added is None else key_added + "_connectivities"
339339
# storing the actual distance and connectivitiy matrices as obsp
340-
adata.uns[dists_key] = sparse.csr_matrix(distances) # noqa: TID251
340+
adata.obsp[dists_key] = sparse.csr_matrix(distances) # noqa: TID251
341341
adata.obsp[conns_key] = connectivities
342342
# populating with metadata describing how neighbors were computed
343343
# I think might be important as many functions downstream rely

tests/test_neighbors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from scanpy import Neighbors
1414
from scanpy._compat import CSBase
1515
from testing.scanpy._helpers import anndata_v0_8_constructor_compat
16+
from testing.scanpy._helpers.data import pbmc68k_reduced
1617

1718
if TYPE_CHECKING:
1819
from typing import Literal
@@ -241,3 +242,22 @@ def test_restore_n_neighbors(neigh, conv):
241242
ad.uns["neighbors"] = dict(connectivities=conv(neigh.connectivities))
242243
neigh_restored = Neighbors(ad)
243244
assert neigh_restored.n_neighbors == 1
245+
246+
247+
def test_neighbors_distance_equivalence():
248+
adata = pbmc68k_reduced()
249+
adata_d = adata.copy()
250+
251+
sc.pp.neighbors(adata)
252+
# reusing the same distances
253+
sc.pp.neighbors(adata_d, distances=adata.obsp["distances"])
254+
np.testing.assert_allclose(
255+
adata.obsp["connectivities"].toarray(),
256+
adata_d.obsp["connectivities"].toarray(),
257+
rtol=1e-5,
258+
)
259+
np.testing.assert_allclose(
260+
adata.obsp["distances"].toarray(),
261+
adata_d.obsp["distances"].toarray(),
262+
rtol=1e-5,
263+
)

0 commit comments

Comments
 (0)