3131
3232if TYPE_CHECKING :
3333 from collections .abc import Callable , MutableMapping
34- from typing import Any , Literal , NotRequired
34+ from typing import Any , Literal , NotRequired , Unpack
3535
3636 from anndata import AnnData
3737 from igraph import Graph
@@ -60,6 +60,13 @@ class KwdsForTransformer(TypedDict):
6060 random_state : _LegacyRandom
6161
6262
63+ class NeighborsDict (TypedDict ): # noqa: D101
64+ connectivities_key : str
65+ distances_key : str
66+ params : NeighborsParams
67+ rp_forest : NotRequired [RPForestDict ]
68+
69+
6370class NeighborsParams (TypedDict ): # noqa: D101
6471 n_neighbors : int
6572 method : _Method
@@ -138,6 +145,7 @@ def neighbors( # noqa: PLR0913
138145 Use :func:`rapids_singlecell.pp.neighbors` instead.
139146 metric
140147 A known metric’s name or a callable that returns a distance.
148+ If `distances` is given, this parameter is simply stored in `.uns` (see below).
141149
142150 *ignored if ``transformer`` is an instance.*
143151 metric_kwds
@@ -190,12 +198,15 @@ def neighbors( # noqa: PLR0913
190198
191199 """
192200 if distances is not None :
193- # Added this to support the new distance matrix function
201+ if callable (metric ):
202+ msg = "`metric` must be a string if `distances` is given."
203+ raise TypeError (msg )
194204 # if a precomputed distance matrix is provided, skip the PCA and distance computation
195205 return neighbors_from_distance (
196206 adata ,
197207 distances ,
198208 n_neighbors = n_neighbors ,
209+ metric = metric ,
199210 method = method ,
200211 )
201212 start = logg .info ("computing neighbors" )
@@ -215,46 +226,31 @@ def neighbors( # noqa: PLR0913
215226 random_state = random_state ,
216227 )
217228
218- if key_added is None :
219- key_added = "neighbors"
220- conns_key = "connectivities"
221- dists_key = "distances"
222- else :
223- conns_key = key_added + "_connectivities"
224- dists_key = key_added + "_distances"
225-
226- adata .uns [key_added ] = {}
227-
228- neighbors_dict = adata .uns [key_added ]
229-
230- neighbors_dict ["connectivities_key" ] = conns_key
231- neighbors_dict ["distances_key" ] = dists_key
232-
233- neighbors_dict ["params" ] = NeighborsParams (
229+ key_added , neighbors_dict = _get_metadata (
230+ key_added ,
234231 n_neighbors = neighbors .n_neighbors ,
235232 method = method ,
236233 random_state = random_state ,
237234 metric = metric ,
235+ ** ({} if not metric_kwds else dict (metric_kwds = metric_kwds )),
236+ ** ({} if use_rep is None else dict (use_rep = use_rep )),
237+ ** ({} if n_pcs is None else dict (n_pcs = n_pcs )),
238238 )
239- if metric_kwds :
240- neighbors_dict ["params" ]["metric_kwds" ] = metric_kwds
241- if use_rep is not None :
242- neighbors_dict ["params" ]["use_rep" ] = use_rep
243- if n_pcs is not None :
244- neighbors_dict ["params" ]["n_pcs" ] = n_pcs
245-
246- adata .obsp [dists_key ] = neighbors .distances
247- adata .obsp [conns_key ] = neighbors .connectivities
248239
249240 if neighbors .rp_forest is not None :
250241 neighbors_dict ["rp_forest" ] = neighbors .rp_forest
242+
243+ adata .uns [key_added ] = neighbors_dict
244+ adata .obsp [neighbors_dict ["distances_key" ]] = neighbors .distances
245+ adata .obsp [neighbors_dict ["connectivities_key" ]] = neighbors .connectivities
246+
251247 logg .info (
252248 " finished" ,
253249 time = start ,
254250 deep = (
255251 f"added to `.uns[{ key_added !r} ]`\n "
256- f" `.obsp[{ dists_key !r} ]`, distances for each pair of neighbors\n "
257- f" `.obsp[{ conns_key !r} ]`, weighted adjacency matrix"
252+ f" `.obsp[{ neighbors_dict [ 'distances_key' ] !r} ]`, distances for each pair of neighbors\n "
253+ f" `.obsp[{ neighbors_dict [ 'connectivities_key' ] !r} ]`, weighted adjacency matrix"
258254 ),
259255 )
260256 return adata if copy else None
@@ -265,6 +261,7 @@ def neighbors_from_distance(
265261 distances : np .ndarray | SpBase ,
266262 * ,
267263 n_neighbors : int = 15 ,
264+ metric : _Metric = "euclidean" ,
268265 method : _Method = "umap" , # default to umap
269266 key_added : str | None = None ,
270267) -> AnnData :
@@ -298,63 +295,59 @@ def neighbors_from_distance(
298295 distances = sparse .csr_matrix (distances ) # noqa: TID251
299296 distances .setdiag (0 )
300297 distances .eliminate_zeros ()
301- # extracting for each observation the indices and distances of the n_neighbors
302- # being then used by umap or gauss
303- knn_indices , knn_distances = _get_indices_distances_from_sparse_matrix (
304- distances , n_neighbors
305- )
306298 else :
307- # if it is dense, converting it to ndarray
308- # and setting the diagonal to 0
309- # extracting knn indices and distances
310299 distances = np .asarray (distances )
311300 np .fill_diagonal (distances , 0 )
312- knn_indices , knn_distances = _get_indices_distances_from_dense_matrix (
313- distances , n_neighbors
314- )
315301
316302 if method == "umap" :
317- # using umap to build connectivities from distances
303+ if isinstance (distances , CSRBase ):
304+ knn_indices , knn_distances = _get_indices_distances_from_sparse_matrix (
305+ distances , n_neighbors
306+ )
307+ else :
308+ knn_indices , knn_distances = _get_indices_distances_from_dense_matrix (
309+ distances , n_neighbors
310+ )
318311 connectivities = umap (
319- knn_indices ,
320- knn_distances ,
321- n_obs = adata .n_obs ,
322- n_neighbors = n_neighbors ,
312+ knn_indices , knn_distances , n_obs = adata .n_obs , n_neighbors = n_neighbors
323313 )
324314 elif method == "gauss" :
325- # using gauss to build connectivities from distances
326- # requires sparse matrix for efficiency
327- connectivities = _connectivity .gauss (
328- sparse .csr_matrix (distances ), # noqa: TID251
329- n_neighbors ,
330- knn = True ,
331- )
315+ distances = sparse .csr_matrix (distances ) # noqa: TID251
316+ connectivities = _connectivity .gauss (distances , n_neighbors , knn = True )
332317 else :
333318 msg = f"Method { method } not implemented."
334319 raise NotImplementedError (msg )
335- # defining where to store graph info
336- key = "neighbors" if key_added is None else key_added
337- dists_key = "distances" if key_added is None else key_added + "_distances"
338- conns_key = "connectivities" if key_added is None else key_added + "_connectivities"
339- # storing the actual distance and connectivitiy matrices as obsp
340- adata .obsp [dists_key ] = sparse .csr_matrix (distances ) # noqa: TID251
341- adata .obsp [conns_key ] = connectivities
342- # populating with metadata describing how neighbors were computed
343- # I think might be important as many functions downstream rely
344- # on .uns['neighbors'] to find correct .obsp key
345- adata .uns [key ] = {
346- "connectivities_key" : "connectivities" ,
347- "distances_key" : "distances" ,
348- "params" : {
349- "n_neighbors" : n_neighbors ,
350- "method" : method ,
351- "random_state" : 0 ,
352- "metric" : "euclidean" ,
353- },
354- }
320+
321+ key_added , neighbors_dict = _get_metadata (
322+ key_added ,
323+ n_neighbors = n_neighbors ,
324+ method = method ,
325+ random_state = 0 ,
326+ metric = metric ,
327+ )
328+ adata .uns [key_added ] = neighbors_dict
329+ adata .obsp [neighbors_dict ["distances_key" ]] = distances
330+ adata .obsp [neighbors_dict ["connectivities_key" ]] = connectivities
355331 return adata
356332
357333
334+ def _get_metadata (
335+ key_added : str | None ,
336+ ** params : Unpack [NeighborsParams ],
337+ ) -> tuple [str , NeighborsDict ]:
338+ if key_added is None :
339+ return "neighbors" , NeighborsDict (
340+ connectivities_key = "connectivities" ,
341+ distances_key = "distances" ,
342+ params = params ,
343+ )
344+ return key_added , NeighborsDict (
345+ connectivities_key = f"{ key_added } _connectivities" ,
346+ distances_key = f"{ key_added } _distances" ,
347+ params = params ,
348+ )
349+
350+
358351class FlatTree (NamedTuple ): # noqa: D101
359352 hyperplanes : None
360353 offsets : None
0 commit comments