@@ -24,9 +24,6 @@ def calculate_niche(
24
24
groups : str ,
25
25
flavor : str = "neighborhood" ,
26
26
library_key : str | None = None ,
27
- radius : float | None = None ,
28
- n_neighbors : int | None = None ,
29
- limit_to : str | list [Any ] | None = None ,
30
27
table_key : str | None = None ,
31
28
spatial_key : str = "spatial" ,
32
29
spatial_connectivities_key : str = "spatial_connectivities" ,
@@ -51,7 +48,7 @@ def calculate_niche(
51
48
- `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication).
52
49
- `{c.ALL.s!r}` - apply all available methods and compare them using cluster validation scores.
53
50
%(library_key)s
54
- limit_to
51
+ subset
55
52
Restrict niche calculation to a subset of the data.
56
53
table_key
57
54
Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed.
@@ -65,14 +62,15 @@ def calculate_niche(
65
62
if isinstance (adata , SpatialData ):
66
63
is_sdata = True
67
64
if table_key is not None :
68
- table = adata .tables [table_key ]
65
+ sdata = adata
66
+ adata = adata .tables [table_key ].copy ()
69
67
else :
70
68
if len (adata .tables ) > 1 :
71
69
count = 0
72
- for key in adata .tables .keys ():
70
+ for table in adata .tables .keys ():
73
71
if groups in table .obs :
74
72
count += 1
75
- table_key = key
73
+ table_key = table
76
74
if count > 1 :
77
75
raise ValueError (
78
76
f"Multiple tables in `spatialdata` with group `{ groups } ` detected. Please specify which table to use in `table_key`."
@@ -82,70 +80,44 @@ def calculate_niche(
82
80
f"Group `{ groups } ` not found in any table in `spatialdata`. Please specify a valid group in `groups`."
83
81
)
84
82
else :
85
- table = adata .tables [table_key ]
83
+ adata = adata .tables [table_key ]. copy ()
86
84
else :
87
- ((key , table ),) = adata .tables .items ()
88
- if groups not in table .obs :
85
+ ((key , adata ),) = adata .tables .items ()
86
+ if groups not in adata .obs :
89
87
raise ValueError (
90
88
f"Group { groups } not found in table in `spatialdata`. Please specify a valid group in `groups`."
91
89
)
92
- else :
93
- table = adata .copy ()
94
-
95
- # check whether to use radius or knn for neighborhood profile calculation
96
- if radius is None and n_neighbors is None :
97
- raise ValueError ("Either `radius` or `n_neighbors` must be provided, but both are `None`." )
98
- if radius is not None and n_neighbors is not None :
99
- raise ValueError ("Either `radius` and `n_neighbors` must be provided, but both were provided." )
100
-
101
- # subset adata if only observations within specified groups are to be considered
102
- if limit_to is not None :
103
- if isinstance (limit_to , str ):
104
- limit_to = [limit_to ]
105
- table_subset = table [table .obs [groups ].isin ([limit_to ])]
106
- else :
107
- table_subset = table
108
90
109
91
if flavor == "neighborhood" :
110
92
rel_nhood_profile , abs_nhood_profile = _calculate_neighborhood_profile (
111
- table , groups , table_subset , spatial_connectivities_key
93
+ adata , groups , spatial_connectivities_key
112
94
)
113
- df = pd .DataFrame (rel_nhood_profile , index = table_subset .obs .index )
95
+ df = pd .DataFrame (rel_nhood_profile , index = adata .obs .index )
114
96
nhood_table = _df_to_adata (df )
115
- sc .pp .neighbors (nhood_table , n_neighbors = n_neighbors , use_rep = "X" )
116
- sc .tl .leiden (nhood_table )
117
- table .obs ["niche" ] = nhood_table .obs ["leiden" ]
118
97
if copy :
119
- return nhood_table
98
+ return df
120
99
else :
121
100
if is_sdata :
122
- adata .tables [f"{ flavor } _niche" ] = nhood_table
101
+ sdata .tables [f"{ flavor } _niche" ] = nhood_table
123
102
else :
124
- df = df .reindex (table .obs .index )
125
- print (df .head ())
126
- table .obsm [f"{ flavor } _niche" ] = df
103
+ adata .obsm ["neighborhood_profile" ] = df
127
104
128
105
elif flavor == "utag" :
129
- new_feature_matrix = _utag (table , normalize_adj = True , spatial_connectivity_key = spatial_connectivities_key )
130
- table .X = new_feature_matrix
106
+ new_feature_matrix = _utag (adata , normalize_adj = True , spatial_connectivity_key = spatial_connectivities_key )
131
107
if copy :
132
- return table
108
+ return new_feature_matrix
133
109
else :
134
110
if is_sdata :
135
- adata .tables [f"{ flavor } _niche" ] = table
111
+ sdata .tables [f"{ flavor } _niche" ] = new_feature_matrix
136
112
else :
137
- table .obsm [f"{ flavor } _niche" ] = table . X
113
+ adata .obsm [f"{ flavor } _niche" ] = new_feature_matrix
138
114
139
115
140
116
def _calculate_neighborhood_profile (
141
117
adata : AnnData | SpatialData ,
142
118
groups : str ,
143
- subset : AnnData ,
144
119
spatial_connectivities_key : str ,
145
120
) -> tuple [pd .DataFrame , pd .DataFrame ]:
146
- # reset index
147
- adata .obs = adata .obs .reset_index ()
148
-
149
121
# get obs x neighbor matrix from sparse matrix
150
122
matrix = adata .obsp [spatial_connectivities_key ].tocoo ()
151
123
nonzero_indices = np .split (matrix .col , matrix .row .searchsorted (np .arange (1 , matrix .shape [0 ])))
0 commit comments