Skip to content

Commit 49b51ca

Browse files
committed
Fix output; Remove subsetting, neighborhood options, dimreduction and clustering steps
1 parent 9fa0157 commit 49b51ca

File tree

1 file changed

+17
-45
lines changed

1 file changed

+17
-45
lines changed

src/squidpy/gr/_niche.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ def calculate_niche(
2424
groups: str,
2525
flavor: str = "neighborhood",
2626
library_key: str | None = None,
27-
radius: float | None = None,
28-
n_neighbors: int | None = None,
29-
limit_to: str | list[Any] | None = None,
3027
table_key: str | None = None,
3128
spatial_key: str = "spatial",
3229
spatial_connectivities_key: str = "spatial_connectivities",
@@ -51,7 +48,7 @@ def calculate_niche(
5148
- `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication).
5249
- `{c.ALL.s!r}` - apply all available methods and compare them using cluster validation scores.
5350
%(library_key)s
54-
limit_to
51+
subset
5552
Restrict niche calculation to a subset of the data.
5653
table_key
5754
Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed.
@@ -65,14 +62,15 @@ def calculate_niche(
6562
if isinstance(adata, SpatialData):
6663
is_sdata = True
6764
if table_key is not None:
68-
table = adata.tables[table_key]
65+
sdata = adata
66+
adata = adata.tables[table_key].copy()
6967
else:
7068
if len(adata.tables) > 1:
7169
count = 0
72-
for key in adata.tables.keys():
70+
for table in adata.tables.keys():
7371
if groups in table.obs:
7472
count += 1
75-
table_key = key
73+
table_key = table
7674
if count > 1:
7775
raise ValueError(
7876
f"Multiple tables in `spatialdata` with group `{groups}` detected. Please specify which table to use in `table_key`."
@@ -82,70 +80,44 @@ def calculate_niche(
8280
f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`."
8381
)
8482
else:
85-
table = adata.tables[table_key]
83+
adata = adata.tables[table_key].copy()
8684
else:
87-
((key, table),) = adata.tables.items()
88-
if groups not in table.obs:
85+
((key, adata),) = adata.tables.items()
86+
if groups not in adata.obs:
8987
raise ValueError(
9088
f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`."
9189
)
92-
else:
93-
table = adata.copy()
94-
95-
# check whether to use radius or knn for neighborhood profile calculation
96-
if radius is None and n_neighbors is None:
97-
raise ValueError("Either `radius` or `n_neighbors` must be provided, but both are `None`.")
98-
if radius is not None and n_neighbors is not None:
99-
raise ValueError("Either `radius` and `n_neighbors` must be provided, but both were provided.")
100-
101-
# subset adata if only observations within specified groups are to be considered
102-
if limit_to is not None:
103-
if isinstance(limit_to, str):
104-
limit_to = [limit_to]
105-
table_subset = table[table.obs[groups].isin([limit_to])]
106-
else:
107-
table_subset = table
10890

10991
if flavor == "neighborhood":
11092
rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile(
111-
table, groups, table_subset, spatial_connectivities_key
93+
adata, groups, spatial_connectivities_key
11294
)
113-
df = pd.DataFrame(rel_nhood_profile, index=table_subset.obs.index)
95+
df = pd.DataFrame(rel_nhood_profile, index=adata.obs.index)
11496
nhood_table = _df_to_adata(df)
115-
sc.pp.neighbors(nhood_table, n_neighbors=n_neighbors, use_rep="X")
116-
sc.tl.leiden(nhood_table)
117-
table.obs["niche"] = nhood_table.obs["leiden"]
11897
if copy:
119-
return nhood_table
98+
return df
12099
else:
121100
if is_sdata:
122-
adata.tables[f"{flavor}_niche"] = nhood_table
101+
sdata.tables[f"{flavor}_niche"] = nhood_table
123102
else:
124-
df = df.reindex(table.obs.index)
125-
print(df.head())
126-
table.obsm[f"{flavor}_niche"] = df
103+
adata.obsm["neighborhood_profile"] = df
127104

128105
elif flavor == "utag":
129-
new_feature_matrix = _utag(table, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key)
130-
table.X = new_feature_matrix
106+
new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key)
131107
if copy:
132-
return table
108+
return new_feature_matrix
133109
else:
134110
if is_sdata:
135-
adata.tables[f"{flavor}_niche"] = table
111+
sdata.tables[f"{flavor}_niche"] = new_feature_matrix
136112
else:
137-
table.obsm[f"{flavor}_niche"] = table.X
113+
adata.obsm[f"{flavor}_niche"] = new_feature_matrix
138114

139115

140116
def _calculate_neighborhood_profile(
141117
adata: AnnData | SpatialData,
142118
groups: str,
143-
subset: AnnData,
144119
spatial_connectivities_key: str,
145120
) -> tuple[pd.DataFrame, pd.DataFrame]:
146-
# reset index
147-
adata.obs = adata.obs.reset_index()
148-
149121
# get obs x neighbor matrix from sparse matrix
150122
matrix = adata.obsp[spatial_connectivities_key].tocoo()
151123
nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0])))

0 commit comments

Comments
 (0)