Skip to content

Commit bd16795

Browse files
committed
Update
1 parent f797ef9 commit bd16795

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/squidpy/gr/_niche.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,16 @@ def calculate_niche(
121121
adj_matrix_subsets = []
122122
if isinstance(adj_subsets, list):
123123
for k in adj_subsets:
124-
adj_matrix_subsets.append(
125-
_get_adj_matrix_subsets(
126-
adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k
124+
if k == 0:
125+
adj_matrix_subsets.append(adata.obsp[spatial_connectivities_key])
126+
else:
127+
adj_matrix_subsets.append(
128+
_get_adj_matrix_subsets(
129+
adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k
130+
)
127131
)
128-
)
129132
if aggregation == "mean":
130-
inner_products = [adata.X.dot(adj_subset) for adj_subset in adj_matrix_subsets]
133+
inner_products = [adj_subset.dot(adata.X) for adj_subset in adj_matrix_subsets]
131134
elif aggregation == "variance":
132135
inner_products = [
133136
_aggregate_var(matrix, adata.obsp[spatial_connectivities_key], adata) for matrix in inner_products
@@ -222,7 +225,7 @@ def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k
222225

223226
# Create the new sparse matrix with the reduced neighbors
224227
new_adj_matrix = csr_matrix((data, (rows, cols)), shape=connectivities.shape)
225-
228+
print(new_adj_matrix.shape)
226229
return new_adj_matrix
227230

228231

0 commit comments

Comments
 (0)