3131
3232if  TYPE_CHECKING :
3333    from  collections .abc  import  Callable , MutableMapping 
34-     from  typing  import  Any , Literal , NotRequired 
34+     from  typing  import  Any , Literal , NotRequired ,  Unpack 
3535
3636    from  anndata  import  AnnData 
3737    from  igraph  import  Graph 
@@ -60,6 +60,13 @@ class KwdsForTransformer(TypedDict):
6060    random_state : _LegacyRandom 
6161
6262
63+ class  NeighborsDict (TypedDict ):  # noqa: D101 
64+     connectivities_key : str 
65+     distances_key : str 
66+     params : NeighborsParams 
67+     rp_forest : NotRequired [RPForestDict ]
68+ 
69+ 
6370class  NeighborsParams (TypedDict ):  # noqa: D101 
6471    n_neighbors : int 
6572    method : _Method 
@@ -138,6 +145,7 @@ def neighbors(  # noqa: PLR0913
138145               Use :func:`rapids_singlecell.pp.neighbors` instead. 
139146    metric 
140147        A known metric’s name or a callable that returns a distance. 
148+         If `distances` is given, this parameter is simply stored in `.uns` (see below). 
141149
142150        *ignored if ``transformer`` is an instance.* 
143151    metric_kwds 
@@ -190,12 +198,15 @@ def neighbors(  # noqa: PLR0913
190198
191199    """ 
192200    if  distances  is  not   None :
193-         # Added this to support the new distance matrix function 
201+         if  callable (metric ):
202+             msg  =  "`metric` must be a string if `distances` is given." 
203+             raise  TypeError (msg )
194204        # if a precomputed distance matrix is provided, skip the PCA and distance computation 
195205        return  neighbors_from_distance (
196206            adata ,
197207            distances ,
198208            n_neighbors = n_neighbors ,
209+             metric = metric ,
199210            method = method ,
200211        )
201212    start  =  logg .info ("computing neighbors" )
@@ -215,46 +226,31 @@ def neighbors(  # noqa: PLR0913
215226        random_state = random_state ,
216227    )
217228
218-     if  key_added  is  None :
219-         key_added  =  "neighbors" 
220-         conns_key  =  "connectivities" 
221-         dists_key  =  "distances" 
222-     else :
223-         conns_key  =  key_added  +  "_connectivities" 
224-         dists_key  =  key_added  +  "_distances" 
225- 
226-     adata .uns [key_added ] =  {}
227- 
228-     neighbors_dict  =  adata .uns [key_added ]
229- 
230-     neighbors_dict ["connectivities_key" ] =  conns_key 
231-     neighbors_dict ["distances_key" ] =  dists_key 
232- 
233-     neighbors_dict ["params" ] =  NeighborsParams (
229+     key_added , neighbors_dict  =  _get_metadata (
230+         key_added ,
234231        n_neighbors = neighbors .n_neighbors ,
235232        method = method ,
236233        random_state = random_state ,
237234        metric = metric ,
235+         ** ({} if  not  metric_kwds  else  dict (metric_kwds = metric_kwds )),
236+         ** ({} if  use_rep  is  None  else  dict (use_rep = use_rep )),
237+         ** ({} if  n_pcs  is  None  else  dict (n_pcs = n_pcs )),
238238    )
239-     if  metric_kwds :
240-         neighbors_dict ["params" ]["metric_kwds" ] =  metric_kwds 
241-     if  use_rep  is  not   None :
242-         neighbors_dict ["params" ]["use_rep" ] =  use_rep 
243-     if  n_pcs  is  not   None :
244-         neighbors_dict ["params" ]["n_pcs" ] =  n_pcs 
245- 
246-     adata .obsp [dists_key ] =  neighbors .distances 
247-     adata .obsp [conns_key ] =  neighbors .connectivities 
248239
249240    if  neighbors .rp_forest  is  not   None :
250241        neighbors_dict ["rp_forest" ] =  neighbors .rp_forest 
242+ 
243+     adata .uns [key_added ] =  neighbors_dict 
244+     adata .obsp [neighbors_dict ["distances_key" ]] =  neighbors .distances 
245+     adata .obsp [neighbors_dict ["connectivities_key" ]] =  neighbors .connectivities 
246+ 
251247    logg .info (
252248        "    finished" ,
253249        time = start ,
254250        deep = (
255251            f"added to `.uns[{ key_added !r}  ]`\n " 
256-             f"    `.obsp[{ dists_key !r}  ]`, distances for each pair of neighbors\n " 
257-             f"    `.obsp[{ conns_key !r}  ]`, weighted adjacency matrix" 
252+             f"    `.obsp[{ neighbors_dict [ 'distances_key' ] !r}  ]`, distances for each pair of neighbors\n " 
253+             f"    `.obsp[{ neighbors_dict [ 'connectivities_key' ] !r}  ]`, weighted adjacency matrix" 
258254        ),
259255    )
260256    return  adata  if  copy  else  None 
@@ -265,6 +261,7 @@ def neighbors_from_distance(
265261    distances : np .ndarray  |  SpBase ,
266262    * ,
267263    n_neighbors : int  =  15 ,
264+     metric : _Metric  =  "euclidean" ,
268265    method : _Method  =  "umap" ,  # default to umap 
269266    key_added : str  |  None  =  None ,
270267) ->  AnnData :
@@ -298,63 +295,59 @@ def neighbors_from_distance(
298295        distances  =  sparse .csr_matrix (distances )  # noqa: TID251 
299296        distances .setdiag (0 )
300297        distances .eliminate_zeros ()
301-         # extracting for each observation the indices and distances of the n_neighbors 
302-         # being then used by umap or gauss 
303-         knn_indices , knn_distances  =  _get_indices_distances_from_sparse_matrix (
304-             distances , n_neighbors 
305-         )
306298    else :
307-         # if it is dense, converting it to ndarray 
308-         # and setting the diagonal to 0 
309-         # extracting knn indices and distances 
310299        distances  =  np .asarray (distances )
311300        np .fill_diagonal (distances , 0 )
312-         knn_indices , knn_distances  =  _get_indices_distances_from_dense_matrix (
313-             distances , n_neighbors 
314-         )
315301
316302    if  method  ==  "umap" :
317-         # using umap to build connectivities from distances 
303+         if  isinstance (distances , CSRBase ):
304+             knn_indices , knn_distances  =  _get_indices_distances_from_sparse_matrix (
305+                 distances , n_neighbors 
306+             )
307+         else :
308+             knn_indices , knn_distances  =  _get_indices_distances_from_dense_matrix (
309+                 distances , n_neighbors 
310+             )
318311        connectivities  =  umap (
319-             knn_indices ,
320-             knn_distances ,
321-             n_obs = adata .n_obs ,
322-             n_neighbors = n_neighbors ,
312+             knn_indices , knn_distances , n_obs = adata .n_obs , n_neighbors = n_neighbors 
323313        )
324314    elif  method  ==  "gauss" :
325-         # using gauss to build connectivities from distances 
326-         # requires sparse matrix for efficiency 
327-         connectivities  =  _connectivity .gauss (
328-             sparse .csr_matrix (distances ),  # noqa: TID251 
329-             n_neighbors ,
330-             knn = True ,
331-         )
315+         distances  =  sparse .csr_matrix (distances )  # noqa: TID251 
316+         connectivities  =  _connectivity .gauss (distances , n_neighbors , knn = True )
332317    else :
333318        msg  =  f"Method { method }   not implemented." 
334319        raise  NotImplementedError (msg )
335-     # defining where to store graph info 
336-     key  =  "neighbors"  if  key_added  is  None  else  key_added 
337-     dists_key  =  "distances"  if  key_added  is  None  else  key_added  +  "_distances" 
338-     conns_key  =  "connectivities"  if  key_added  is  None  else  key_added  +  "_connectivities" 
339-     # storing the actual distance and connectivitiy matrices as obsp 
340-     adata .obsp [dists_key ] =  sparse .csr_matrix (distances )  # noqa: TID251 
341-     adata .obsp [conns_key ] =  connectivities 
342-     # populating with metadata describing how neighbors were computed 
343-     # I think might be important as many functions downstream rely 
344-     # on .uns['neighbors'] to find correct .obsp key 
345-     adata .uns [key ] =  {
346-         "connectivities_key" : "connectivities" ,
347-         "distances_key" : "distances" ,
348-         "params" : {
349-             "n_neighbors" : n_neighbors ,
350-             "method" : method ,
351-             "random_state" : 0 ,
352-             "metric" : "euclidean" ,
353-         },
354-     }
320+ 
321+     key_added , neighbors_dict  =  _get_metadata (
322+         key_added ,
323+         n_neighbors = n_neighbors ,
324+         method = method ,
325+         random_state = 0 ,
326+         metric = metric ,
327+     )
328+     adata .uns [key_added ] =  neighbors_dict 
329+     adata .obsp [neighbors_dict ["distances_key" ]] =  distances 
330+     adata .obsp [neighbors_dict ["connectivities_key" ]] =  connectivities 
355331    return  adata 
356332
357333
334+ def  _get_metadata (
335+     key_added : str  |  None ,
336+     ** params : Unpack [NeighborsParams ],
337+ ) ->  tuple [str , NeighborsDict ]:
338+     if  key_added  is  None :
339+         return  "neighbors" , NeighborsDict (
340+             connectivities_key = "connectivities" ,
341+             distances_key = "distances" ,
342+             params = params ,
343+         )
344+     return  key_added , NeighborsDict (
345+         connectivities_key = f"{ key_added }  _connectivities" ,
346+         distances_key = f"{ key_added }  _distances" ,
347+         params = params ,
348+     )
349+ 
350+ 
358351class  FlatTree (NamedTuple ):  # noqa: D101 
359352    hyperplanes : None 
360353    offsets : None 
0 commit comments