Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def test(
# much faster than applymap (tested on 1M interactions)
interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values)

rng = np.random.default_rng(seed)
n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
start = logg.info(
f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
Expand All @@ -420,7 +421,7 @@ def test(
clusters_,
threshold=threshold,
n_perms=n_perms,
seed=seed,
rng=rng,
n_jobs=n_jobs,
numba_parallel=numba_parallel,
**kwargs,
Expand Down Expand Up @@ -642,7 +643,17 @@ def ligrec(
copy: bool = False,
key_added: str | None = None,
gene_symbols: str | None = None,
**kwargs: Any,
n_perms: int = 1000,
seed: int | None = None,
clusters: Cluster_t | None = None,
alpha: float = 0.05,
numba_parallel: bool | None = None,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
interactions_params: Mapping[str, Any] = MappingProxyType({}),
transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}),
receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}),
) -> Mapping[str, pd.DataFrame] | None:
"""
%(PT_test.full_desc)s
Expand All @@ -664,15 +675,28 @@ def ligrec(
with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False):
return ( # type: ignore[no-any-return]
PermutationTest(adata, use_raw=use_raw)
.prepare(interactions, complex_policy=complex_policy, **kwargs)
.prepare(
interactions,
complex_policy=complex_policy,
interactions_params=interactions_params,
transmitter_params=transmitter_params,
receiver_params=receiver_params,
)
.test(
cluster_key=cluster_key,
clusters=clusters,
n_perms=n_perms,
threshold=threshold,
seed=seed,
corr_method=corr_method,
corr_axis=corr_axis,
alpha=alpha,
copy=copy,
key_added=key_added,
**kwargs,
numba_parallel=numba_parallel,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)
)

Expand All @@ -682,9 +706,9 @@ def _analysis(
data: pd.DataFrame,
interactions: NDArrayA,
interaction_clusters: NDArrayA,
rng: np.random.Generator,
threshold: float = 0.1,
n_perms: int = 1000,
seed: int | None = None,
n_jobs: int = 1,
numba_parallel: bool | None = None,
**kwargs: Any,
Expand All @@ -705,7 +729,8 @@ def _analysis(
threshold
Percentage threshold for removing lowly expressed genes in clusters.
%(n_perms)s
%(seed)s
rng
NumPy :class:`numpy.random.Generator` for reproducibility.
n_jobs
Number of parallel jobs to launch.
numba_parallel
Expand Down Expand Up @@ -750,6 +775,7 @@ def extractor(res: Sequence[TempResult]) -> TempResult:

# (n_cells, n_genes)
data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C")
root_seed = rng.integers(np.iinfo(np.int64).max)
# all 3 should be C contiguous
return parallelize( # type: ignore[no-any-return]
_analysis_helper,
Expand All @@ -765,7 +791,7 @@ def extractor(res: Sequence[TempResult]) -> TempResult:
interactions,
interaction_clusters=interaction_clusters,
clustering=clustering,
seed=seed,
root_seed=root_seed,
numba_parallel=numba_parallel,
)

Expand All @@ -778,7 +804,7 @@ def _analysis_helper(
interactions: NDArrayA,
interaction_clusters: NDArrayA,
clustering: NDArrayA,
seed: int | None = None,
root_seed: int,
numba_parallel: bool | None = None,
queue: SigQueue | None = None,
) -> TempResult:
Expand All @@ -788,7 +814,7 @@ def _analysis_helper(
Parameters
----------
perms
Permutation indices. Only used to set the ``seed``.
Permutation indices. Only used to differentiate workers/permutations.
data
Array of shape `(n_cells, n_genes)`.
mean
Expand All @@ -802,8 +828,9 @@ def _analysis_helper(
Array of shape `(n_interaction_clusters, 2)`.
clustering
Array of shape `(n_cells,)` containing the original clustering.
seed
Random seed for :class:`numpy.random.RandomState`.
root_seed
Integer seed derived from the root generator. Each worker creates
an independent stream via ``default_rng([perms[0], root_seed])``.
numba_parallel
Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically.
queue
Expand All @@ -818,7 +845,7 @@ def _analysis_helper(
- `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)`
where `T0` is the test statistic under null hypothesis and `T` is the true test statistic.
"""
rs = np.random.RandomState(None if seed is None else perms[0] + seed)
rng = np.random.default_rng([perms[0], root_seed])

clustering = clustering.copy()
n_cls = mean.shape[1]
Expand Down Expand Up @@ -847,7 +874,7 @@ def _analysis_helper(
test = _test

for _ in perms:
rs.shuffle(clustering)
rng.shuffle(clustering)
error = test(interactions, interaction_clusters, data, clustering, mean, mask, res=res)
if error:
raise ValueError("In the execution of the numba function, an unhandled case was encountered. ")
Expand Down
Binary file modified tests/_data/ligrec_no_numba.pickle
Binary file not shown.
Binary file added tests/_data/ligrec_pvalues_reference.pickle
Binary file not shown.
Binary file modified tests/_images/Ligrec_pvalue_threshold.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Ligrec_remove_nonsig_interactions.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def ligrec_no_numba() -> Mapping[str, pd.DataFrame]:
return {"means": data[0], "pvalues": data[1], "metadata": data[2]}


@pytest.fixture(scope="session")
def ligrec_pvalues_reference() -> Mapping[str, pd.DataFrame]:
    """Load the pickled reference ligrec result used for regression testing.

    Session-scoped so the pickle is deserialized only once per test run.
    """
    with open("tests/_data/ligrec_pvalues_reference.pickle", "rb") as handle:
        # pickle.loads on the full payload is equivalent to pickle.load(handle)
        return pickle.loads(handle.read())
Comment on lines +268 to +269
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we store this as a zarr file or hdf5 or even just an anndata?



@pytest.fixture(scope="session")
def ligrec_result() -> Mapping[str, pd.DataFrame]:
adata = _adata.copy()
Expand Down
17 changes: 17 additions & 0 deletions tests/graph/test_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,23 @@ def test_reproducibility_numba_off(
np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"])
np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"])))

def test_pvalues_reference(
    self, adata: AnnData, interactions: Interactions_t, ligrec_pvalues_reference: Mapping[str, pd.DataFrame]
):
    """Check that a seeded ligrec run reproduces the stored reference output."""
    res = ligrec(
        adata, _CK, interactions=interactions, n_perms=25, copy=True, show_progress_bar=False, seed=42, n_jobs=1
    )

    # Labels first: means.index, means.columns, pvalues.index, pvalues.columns.
    for frame in ("means", "pvalues"):
        np.testing.assert_array_equal(res[frame].index, ligrec_pvalues_reference[frame].index)
        np.testing.assert_array_equal(res[frame].columns, ligrec_pvalues_reference[frame].columns)

    # Then the values themselves, in the same frame order.
    for frame in ("means", "pvalues"):
        np.testing.assert_allclose(res[frame], ligrec_pvalues_reference[frame])

    # NaN positions in the p-values must match exactly (allclose alone does not pin them).
    np.testing.assert_array_equal(
        np.where(np.isnan(res["pvalues"])), np.where(np.isnan(ligrec_pvalues_reference["pvalues"]))
    )

def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys):
s.logfile = sys.stderr
s.verbosity = 4
Expand Down