@@ -27,9 +27,9 @@ def calculate_niche(
27
27
flavor : str = "neighborhood" ,
28
28
library_key : str | None = None ,
29
29
table_key : str | None = None ,
30
- spatial_key : str = "spatial" ,
31
30
adj_subsets : list [int ] | None = None ,
32
31
aggregation : str = "mean" ,
32
+ spatial_key : str = "spatial" ,
33
33
spatial_connectivities_key : str = "spatial_connectivities" ,
34
34
spatial_distances_key : str = "spatial_distances" ,
35
35
copy : bool = False ,
@@ -174,11 +174,14 @@ def _calculate_neighborhood_profile(
174
174
neighbor_matrix = pd .DataFrame (nonzero_indices )
175
175
176
176
# get unique categories
177
- category_arr = adata .obs [groups ].values
178
- unique_categories = np .unique (category_arr )
177
+ unique_categories = np .unique (adata .obs [groups ].values )
179
178
180
179
# get obs x k matrix where each column is the category of the k-th neighbor
181
- cat_by_id = np .take (category_arr , neighbor_matrix )
180
+ indices_with_nan = neighbor_matrix .to_numpy ()
181
+ valid_indices = neighbor_matrix .fillna (- 1 ).astype (int ).to_numpy ()
182
+ cat_by_id = adata .obs [groups ].values [valid_indices ]
183
+ cat_by_id [indices_with_nan == - 1 ] = np .nan
184
+ # cat_by_id = np.take(category_arr, neighbor_matrix)
182
185
183
186
# in obs x k matrix convert categorical values to numerical values
184
187
cat_indices = {category : index for index , category in enumerate (unique_categories )}
0 commit comments