4
4
from collections .abc import Iterator
5
5
from typing import Any , Optional
6
6
7
+ import anndata as ad
7
8
import numpy as np
8
9
import pandas as pd
9
10
import scanpy as sc
10
11
from anndata import AnnData
12
+ from scipy .sparse import csr_matrix , vstack
11
13
from scipy .stats import ranksums
12
14
from sklearn import metrics
13
15
from sklearn .metrics import adjusted_rand_score , fowlkes_mallows_score , normalized_mutual_info_score
@@ -26,7 +28,10 @@ def calculate_niche(
26
28
library_key : str | None = None ,
27
29
table_key : str | None = None ,
28
30
spatial_key : str = "spatial" ,
31
+ adj_subsets : list [int ] | None = None ,
32
+ aggregation : str = "mean" ,
29
33
spatial_connectivities_key : str = "spatial_connectivities" ,
34
+ spatial_distances_key : str = "spatial_distances" ,
30
35
copy : bool = False ,
31
36
) -> AnnData | pd .DataFrame :
32
37
"""Calculate niches (spatial clusters) based on a user-defined method in 'flavor'.
@@ -112,6 +117,38 @@ def calculate_niche(
112
117
else :
113
118
adata .layers ["utag" ] = new_feature_matrix
114
119
120
+ elif flavor == "cellcharter" :
121
+ adj_matrix_subsets = []
122
+ if isinstance (adj_subsets , list ):
123
+ for k in adj_subsets :
124
+ adj_matrix_subsets .append (
125
+ _get_adj_matrix_subsets (
126
+ adata .obsp [spatial_connectivities_key ], adata .obsp [spatial_distances_key ], k
127
+ )
128
+ )
129
+ if aggregation == "mean" :
130
+ inner_products = [adata .X .dot (adj_subset ) for adj_subset in adj_matrix_subsets ]
131
+ elif aggregation == "variance" :
132
+ inner_products = [
133
+ _aggregate_var (matrix , adata .obsp [spatial_connectivities_key ], adata ) for matrix in inner_products
134
+ ]
135
+ else :
136
+ raise ValueError (
137
+ f"Invalid aggregation method '{ aggregation } '. Please choose either 'mean' or 'variance'."
138
+ )
139
+ concatenated_matrix = vstack (inner_products )
140
+ if copy :
141
+ return concatenated_matrix
142
+ else :
143
+ if is_sdata :
144
+ sdata .tables [f"{ flavor } _niche" ] = ad .AnnData (concatenated_matrix )
145
+ else :
146
+ adata .obsm [f"{ flavor } _niche" ] = concatenated_matrix
147
+ else :
148
+ raise ValueError (
149
+ "Flavor 'cellcharter' requires list of neighbors to build adjacency matrices. Please provide a list of k_neighbors for 'adj_subsets'."
150
+ )
151
+
115
152
116
153
def _calculate_neighborhood_profile (
117
154
adata : AnnData | SpatialData ,
@@ -164,13 +201,43 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) ->
164
201
return adjacency_matrix @ adata .X
165
202
166
203
204
+ def _get_adj_matrix_subsets (connectivities : csr_matrix , distances : csr_matrix , k_neighbors : int ) -> csr_matrix :
205
+ # Convert the distance matrix to a dense format for easier manipulation
206
+ dist_dense = distances .todense ()
207
+
208
+ # Find the indices of the k closest neighbors for each row
209
+ closest_neighbors_indices = np .argsort (dist_dense , axis = 1 )[:, :k_neighbors ]
210
+
211
+ # Initialize lists to collect data for the new sparse matrix
212
+ rows = []
213
+ cols = []
214
+ data = []
215
+
216
+ # Iterate over each row to construct the new adjacency matrix
217
+ for row in range (dist_dense .shape [0 ]):
218
+ for col in closest_neighbors_indices [row ].flat :
219
+ rows .append (row )
220
+ cols .append (col )
221
+ data .append (connectivities [row , col ])
222
+
223
+ # Create the new sparse matrix with the reduced neighbors
224
+ new_adj_matrix = csr_matrix ((data , (rows , cols )), shape = connectivities .shape )
225
+
226
+ return new_adj_matrix
227
+
228
+
167
229
def _df_to_adata (df : pd .DataFrame ) -> AnnData :
168
230
df .index = df .index .map (str )
169
231
adata = AnnData (X = df )
170
232
adata .obs .index = df .index
171
233
return adata
172
234
173
235
236
+ def _aggregate_var (product : csr_matrix , connectivities : csr_matrix , adata : AnnData ) -> csr_matrix :
237
+ mean_squared = connectivities .dot (adata .X .multiply (adata .X ))
238
+ return mean_squared - (product .multiply (product ))
239
+
240
+
174
241
def pairwise_niche_comparison (
175
242
adata : AnnData ,
176
243
library_key : str ,
0 commit comments