diff --git a/src/squidpy/gr/_ligrec.py b/src/squidpy/gr/_ligrec.py index d7766e5b0..d54ce91f0 100644 --- a/src/squidpy/gr/_ligrec.py +++ b/src/squidpy/gr/_ligrec.py @@ -735,11 +735,15 @@ def extractor(res: Sequence[TempResult]) -> TempResult: clustering = np.array(data["clusters"].values, dtype=np.int32) mean = groups.mean().values.T # (n_genes, n_clusters) - mask = groups.apply(lambda c: ((c > 0).sum() / len(c)) >= threshold).values.T # (n_genes, n_clusters) + # see https://github.com/scverse/squidpy/pull/991#issuecomment-2888506296 + # for why we need to cast to int64 here + mask = groups.apply( + lambda c: ((c > 0).astype(np.int64).sum() / len(c)) >= threshold + ).values.T # (n_genes, n_clusters) + # (n_cells, n_genes) data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C") # all 3 should be C contiguous - return parallelize( # type: ignore[no-any-return] _analysis_helper, np.arange(n_perms, dtype=np.int32).tolist(), diff --git a/tests/graph/test_ligrec.py b/tests/graph/test_ligrec.py index 0145999d6..099ecb1b6 100644 --- a/tests/graph/test_ligrec.py +++ b/tests/graph/test_ligrec.py @@ -14,6 +14,7 @@ from pandas.testing import assert_frame_equal from scanpy import settings as s from scanpy.datasets import blobs +from scipy.sparse import csc_matrix from squidpy._constants._pkg_constants import Key from squidpy.gr import ligrec @@ -461,3 +462,104 @@ def test_none_source_target(self, adata: AnnData): ) assert isinstance(pt.interactions, pd.DataFrame) assert len(pt.interactions) == 1 + + def test_ligrec_nan_counts(self): + """ + For the test case with 2 clusters (A, B) and 3 gene pairs (Gene1→Gene2, Gene2→Gene3, Gene3→Gene1): + + The mask is computed for each gene in each cluster as: + mask[gene, cluster] = (number of cells with value > 0) / (total cells in cluster) >= threshold + + Number of cells with value > 0 in each cluster: + Cluster A: [1, 3, 0] + Cluster B: [1, 0, 3] + + Number of cells with value > 0 in each cluster divided by total number of cells in the cluster: + Cluster A: [1/3, 3/3, 0/3] = [0.33, 1.0, 0.0] + Cluster B: [1/3, 0/3, 3/3] = [0.33, 0.0, 1.0] + + Using threshold=0.8 on this data, the mask is: + Cluster A: [False, True, False] + Cluster B: [False, False, True] + + A value in the result becomes NaN if either: + - The ligand's mask is False in the source cluster, OR + - The receptor's mask is False in the target cluster + + Only in one combination, the mask is both True in the source and target cluster. + This is the case for Gene2→Gene3 in A→B. + + This means from all the possible cluster pairs (A→A, A→B, B→A, B→B) and gene pairs (Gene1→Gene2, Gene2→Gene3, Gene3→Gene1), + (4 cluster pairs * 3 gene pairs = 12 combinations) only one combination is non-NaN. + + Therefore, the total number of NaNs is 11. + + The expected p-values are: + cluster_1 A B + cluster_2 A B A B + source target + GENE1 GENE2 NaN NaN NaN NaN + GENE2 GENE3 NaN 0.0 NaN NaN + GENE3 GENE1 NaN NaN NaN NaN + + """ + # only Gene2→Gene3 is non-NaN + # + + expected_pvalues = np.array( + [ + [ + np.nan, + np.nan, + np.nan, + np.nan, + ], + [ + np.nan, + 0.0, + np.nan, + np.nan, + ], + [ + np.nan, + np.nan, + np.nan, + np.nan, + ], + ] + ) + + expected_nans = 11 + # Setup test data + threshold = 0.8 + interactions = pd.DataFrame({"source": ["Gene1", "Gene2", "Gene3"], "target": ["Gene2", "Gene3", "Gene1"]}) + + # Create sparse matrix with test data + X = csc_matrix( + [ + [1.0, 0.1, 0.0], # A1 + [0.0, 1.0, 0.0], # A2 + [0.0, 1.0, 0.0], # A3 + [0.1, 0.0, 1.0], # B1 + [0.0, 0.0, 1.0], # B2 + [0.0, 0.0, 1.0], # B3 + ] + ) + + # Create AnnData object + adata = AnnData( + X=X, + obs=pd.DataFrame({"cluster": ["A"] * 3 + ["B"] * 3}, index=[f"cell{i}" for i in range(1, 7)]), + var=pd.DataFrame(index=["Gene1", "Gene2", "Gene3"]), + ) + adata.obs["cluster"] = adata.obs["cluster"].astype("category") + + # Run ligrec and compare NaN counts + res = ligrec( + adata, cluster_key="cluster", interactions=interactions, threshold=threshold, use_raw=False, copy=True + ) + + actual_nans = np.sum(np.isnan(res["pvalues"].values)) + + assert actual_nans == expected_nans, f"NaN count mismatch: expected {expected_nans}, got {actual_nans}" + np.testing.assert_array_equal(res["pvalues"].values, expected_pvalues)