Skip to content

Commit 4896263

Browse files
committed
fixed n_hop logic
1 parent 987ed5d commit 4896263

File tree

1 file changed

+70
-31
lines changed

1 file changed

+70
-31
lines changed

src/squidpy/gr/_niche.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import warnings
45
from typing import Any, Literal
56

@@ -148,19 +149,19 @@ def calculate_niche(
148149
"If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`."
149150
)
150151

152+
result_columns = _get_result_columns(
153+
flavor=flavor,
154+
resolutions=resolutions,
155+
library_key=None,
156+
libraries=None,
157+
)
158+
151159
if library_key is not None:
152160
if library_key not in adata.obs.columns:
153161
raise KeyError(f"'{library_key}' not found in `adata.obs`.")
154162

155163
logg.info(f"Stratifying by library_key '{library_key}'")
156164

157-
result_columns = _get_result_columns(
158-
flavor=flavor,
159-
resolutions=resolutions,
160-
library_key=None,
161-
libraries=None,
162-
)
163-
164165
for col in result_columns:
165166
adata.obs[col] = "not_a_niche"
166167

@@ -226,15 +227,25 @@ def calculate_niche(
226227
spatial_connectivities_key,
227228
)
228229

229-
if inplace:
230-
if isinstance(data, SpatialData):
231-
data.tables[table_key] = adata
232-
else:
233-
for col in adata.obs.columns:
234-
if "niche" in col and col not in orig_adata.obs.columns:
235-
orig_adata.obs[col] = adata.obs[col]
230+
if not inplace:
231+
return adata
232+
# For SpatialData, update the table directly
233+
if isinstance(data, SpatialData):
234+
data.tables[table_key] = adata
235+
else:
236+
# For AnnData, copy results back to original object
237+
for col in result_columns:
238+
if col in orig_adata.obs.columns:
239+
logg.info(f"Overwriting existing column '{col}'")
240+
with contextlib.suppress(KeyError):
241+
del orig_adata.obs[col]
242+
if f"{col}_colors" in orig_adata.uns.keys():
243+
with contextlib.suppress(KeyError):
244+
del orig_adata.uns[f"{col}_colors"]
245+
246+
orig_adata.obs[col] = adata.obs[col]
236247

237-
return adata if not inplace else orig_adata
248+
return None
238249

239250

240251
def _get_result_columns(
@@ -345,17 +356,32 @@ def _get_nhood_profile_niches(
345356
if n_hop_weights is None:
346357
n_hop_weights = [1] * distance
347358
# if weights are provided, start with applying weight to the original neighborhood profile
348-
else:
349-
nhood_profile = n_hop_weights[0] * nhood_profile
359+
elif len(n_hop_weights) < distance:
360+
# Extend weights if too few provided
361+
n_hop_weights = n_hop_weights + [n_hop_weights[-1]] * (distance - len(n_hop_weights))
362+
logg.debug(f"Extended weights to match distance: {n_hop_weights}")
363+
364+
# Apply first weight to base profile
365+
weighted_profile = n_hop_weights[0] * nhood_profile
366+
367+
# Calculate higher-order hop profiles
368+
n_hop_adjacency_matrix = adata_masked.obsp[spatial_connectivities_key].copy()
369+
350370
# get n_hop neighbor adjacency matrices by multiplying the original adjacency matrix with itself n times and get corresponding neighborhood profiles.
351-
for n_hop in range(distance - 1):
371+
for n_hop in range(1, distance):
372+
logg.debug(f"Calculating {n_hop + 1}-hop neighbors")
373+
# Multiply adjacency matrix by itself to get n+1 hop adjacency
352374
n_hop_adjacency_matrix = n_hop_adjacency_matrix @ adata_masked.obsp[spatial_connectivities_key]
353375
matrix = n_hop_adjacency_matrix.tocoo()
354-
nhood_profile += n_hop_weights[n_hop + 1] * _calculate_neighborhood_profile(
355-
adata_masked, groups, matrix, abs_nhood
356-
)
376+
377+
# Calculate and add weighted profile
378+
hop_profile = _calculate_neighborhood_profile(adata_masked, groups, matrix, abs_nhood)
379+
weighted_profile += n_hop_weights[n_hop] * hop_profile
380+
357381
if not abs_nhood:
358-
nhood_profile = nhood_profile / sum(n_hop_weights)
382+
weighted_profile = weighted_profile / sum(n_hop_weights)
383+
384+
nhood_profile = weighted_profile
359385

360386
# create AnnData object from neighborhood profile to perform scanpy functions
361387
adata_neighborhood = ad.AnnData(X=nhood_profile)
@@ -376,25 +402,38 @@ def _get_nhood_profile_niches(
376402

377403
# For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label
378404
for res in resolutions:
405+
niche_key = f"nhood_niche_res={res}"
406+
407+
if niche_key in adata_masked.obs.columns:
408+
del adata_masked.obs[niche_key]
409+
410+
if f"{niche_key}_colors" in adata_masked.uns.keys():
411+
del adata_masked.uns[f"{niche_key}_colors"]
412+
# print(adata_masked.obs[niche_key])
413+
379414
sc.tl.leiden(
380415
adata_neighborhood,
381416
resolution=res,
382-
key_added=f"nhood_niche_res={res}",
417+
key_added=niche_key,
383418
)
384-
adata_masked.obs[f"nhood_niche_res={res}"] = adata_masked.obs.index.map(
385-
adata_neighborhood.obs[f"nhood_niche_res={res}"]
386-
).fillna("not_a_niche")
419+
420+
adata_masked.obs[niche_key] = "not_a_niche"
421+
422+
neighborhood_clusters = dict(zip(adata_neighborhood.obs.index, adata_neighborhood.obs[niche_key], strict=False))
423+
424+
mask_indices = adata_masked.obs.index
425+
adata_masked.obs.loc[mask_indices, niche_key] = [
426+
neighborhood_clusters.get(idx, "not_a_niche") for idx in mask_indices
427+
]
387428

388429
# filter niches with n_cells < min_niche_size
389430
if min_niche_size is not None:
390-
counts_by_niche = adata_masked.obs[f"nhood_niche_res={res}"].value_counts()
431+
counts_by_niche = adata_masked.obs[niche_key].value_counts()
391432
to_filter = counts_by_niche[counts_by_niche < min_niche_size].index
392-
adata_masked.obs[f"nhood_niche_res={res}"] = adata_masked.obs[f"nhood_niche_res={res}"].apply(
433+
adata_masked.obs[niche_key] = adata_masked.obs[niche_key].apply(
393434
lambda x, to_filter=to_filter: "not_a_niche" if x in to_filter else x
394435
)
395-
adata.obs[f"nhood_niche_res={res}"] = adata.obs.index.map(
396-
adata_masked.obs[f"nhood_niche_res={res}"]
397-
).fillna("not_a_niche")
436+
adata_masked.obs[niche_key] = adata_masked.obs.index.map(adata_masked.obs[niche_key]).fillna("not_a_niche")
398437

399438
return
400439

0 commit comments

Comments
 (0)