
Commit 3a9a0ec

Backport PR #3853: fix: scanpy benchmark uses pip for igraph + setup_cache (#3859)
1 parent 2153086 commit 3a9a0ec

File tree

5 files changed: +129 -159 lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ jobs:
           key: benchmark-state-${{ hashFiles('benchmarks/**') }}

       - name: Install dependencies
-        run: pip install 'asv>=0.6.4'
+        run: pip install 'asv>=0.6.4' py-rattler

       - name: Configure ASV
         working-directory: ${{ env.ASV_DIR }}

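Note: the "rattler" environment type configured in benchmarks/asv.conf.json below is backed by the py-rattler package, which is why the workflow now installs py-rattler alongside asv itself.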
benchmarks/asv.conf.json

Lines changed: 4 additions & 5 deletions
@@ -43,7 +43,7 @@
     // If missing or the empty string, the tool will be automatically
     // determined by looking for tools on the PATH environment
     // variable.
-    "environment_type": "conda",
+    "environment_type": "rattler",

     // timeout in seconds for installing any dependencies in environment
     // defaults to 10 min

@@ -78,13 +78,12 @@
         "natsort": [""],
         "pandas": [""],
         "memory_profiler": [""],
-        "zarr": ["2.18.4"],
+        "zarr": [""],
         "pytest": [""],
-        "scanpy": [""],
-        "python-igraph": [""],
+        "pip+igraph": [""], // https://github.com/airspeed-velocity/asv/issues/1554
         // "psutil": [""]
         "pooch": [""],
-        "scikit-image": [""],
+        "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
         // "scikit-misc": [""],
     },

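Note: in an ASV dependency matrix, entries prefixed with "pip+" are installed with pip inside the otherwise rattler-managed environment; igraph is moved to pip here as a workaround for the asv issue linked above.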
benchmarks/benchmarks/preprocessing_counts.py

Lines changed: 62 additions & 61 deletions
@@ -5,73 +5,64 @@

 from __future__ import annotations

+from itertools import product
 from typing import TYPE_CHECKING

+import anndata as ad
+
 import scanpy as sc

 from ._utils import get_count_dataset

 if TYPE_CHECKING:
-    from anndata import AnnData
-
     from ._utils import Dataset, KeyCount

-# setup variables
-
-adata: AnnData
-batch_key: str | None
-
-
-def setup(dataset: Dataset, layer: KeyCount, *_):
-    """Set up global variables before each benchmark."""
-    global adata, batch_key
-    adata, batch_key = get_count_dataset(dataset, layer=layer)
-    assert "log1p" not in adata.uns
-

 # ASV suite
+class PreprocessingCountsSuite:  # noqa: D101
+    params: tuple[list[Dataset], list[KeyCount]] = (
+        ["pbmc68k_reduced", "pbmc3k"],
+        ["counts", "counts-off-axis"],
+    )
+    param_names = ("dataset", "layer")

-params: tuple[list[Dataset], list[KeyCount]] = (
-    ["pbmc68k_reduced", "pbmc3k"],
-    ["counts", "counts-off-axis"],
-)
-param_names = ["dataset", "layer"]
-
-
-def time_filter_cells(*_):
-    sc.pp.filter_cells(adata, min_genes=100)
-
-
-def peakmem_filter_cells(*_):
-    sc.pp.filter_cells(adata, min_genes=100)
-
-
-def time_filter_genes(*_):
-    sc.pp.filter_genes(adata, min_cells=3)
-
+    def setup_cache(self) -> None:
+        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
+        for dataset, layer in product(*self.params):
+            adata, batch_key = get_count_dataset(dataset, layer=layer)
+            assert "log1p" not in adata.uns
+            adata.uns["batch_key"] = batch_key
+            adata.write_h5ad(f"{dataset}_{layer}.h5ad")

-def peakmem_filter_genes(*_):
-    sc.pp.filter_genes(adata, min_cells=3)
+    def setup(self, dataset, layer) -> None:
+        self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")

+    def time_filter_cells(self, *_) -> None:
+        sc.pp.filter_cells(self.adata, min_genes=100)

-def time_scrublet(*_):
-    sc.pp.scrublet(adata, batch_key=batch_key)
+    def peakmem_filter_cells(self, *_) -> None:
+        sc.pp.filter_cells(self.adata, min_genes=100)

+    def time_filter_genes(self, *_) -> None:
+        sc.pp.filter_genes(self.adata, min_cells=3)

-def peakmem_scrublet(*_):
-    sc.pp.scrublet(adata, batch_key=batch_key)
+    def peakmem_filter_genes(self, *_) -> None:
+        sc.pp.filter_genes(self.adata, min_cells=3)

+    def time_scrublet(self, *_) -> None:
+        sc.pp.scrublet(self.adata, batch_key=self.adata.uns["batch_key"])

-# Can’t do seurat v3 yet: https://github.com/conda-forge/scikit-misc-feedstock/issues/17
-"""
-def time_hvg_seurat_v3(*_):
-    # seurat v3 runs on counts
-    sc.pp.highly_variable_genes(adata, flavor="seurat_v3_paper")
+    def peakmem_scrublet(self, *_) -> None:
+        sc.pp.scrublet(self.adata, batch_key=self.adata.uns["batch_key"])

+    # scikit-misc does not exist on osx-arm64
+    # https://github.com/conda-forge/scikit-misc-feedstock/pull/29
+    # def time_hvg_seurat_v3(self, *_):
+    #     # seurat v3 runs on counts
+    #     sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3_paper")

-def peakmem_hvg_seurat_v3(*_):
-    sc.pp.highly_variable_genes(adata, flavor="seurat_v3_paper")
-"""
+    # def peakmem_hvg_seurat_v3(self, *_):
+    #     sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3_paper")


 class FastSuite:

@@ -83,28 +74,38 @@ class FastSuite:
     )
     param_names = ("dataset", "layer")

-    def time_calculate_qc_metrics(self, *_):
+    def setup_cache(self) -> None:
+        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
+        for dataset, layer in product(*self.params):
+            adata, _ = get_count_dataset(dataset, layer=layer)
+            assert "log1p" not in adata.uns
+            adata.write_h5ad(f"{dataset}_{layer}.h5ad")
+
+    def setup(self, dataset, layer) -> None:
+        self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")
+
+    def time_calculate_qc_metrics(self, *_) -> None:
         sc.pp.calculate_qc_metrics(
-            adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
+            self.adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
         )

-    def peakmem_calculate_qc_metrics(self, *_):
+    def peakmem_calculate_qc_metrics(self, *_) -> None:
         sc.pp.calculate_qc_metrics(
-            adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
+            self.adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
         )

-    def time_normalize_total(self, *_):
-        sc.pp.normalize_total(adata, target_sum=1e4)
+    def time_normalize_total(self, *_) -> None:
+        sc.pp.normalize_total(self.adata, target_sum=1e4)

-    def peakmem_normalize_total(self, *_):
-        sc.pp.normalize_total(adata, target_sum=1e4)
+    def peakmem_normalize_total(self, *_) -> None:
+        sc.pp.normalize_total(self.adata, target_sum=1e4)

-    def time_log1p(self, *_):
-        # TODO: This would fail: assert "log1p" not in adata.uns, "ASV bug?"
+    def time_log1p(self, *_) -> None:
+        # TODO: This would fail: assert "log1p" not in self.adata.uns, "ASV bug?"
         # https://github.com/scverse/scanpy/issues/3052
-        adata.uns.pop("log1p", None)
-        sc.pp.log1p(adata)
+        self.adata.uns.pop("log1p", None)
+        sc.pp.log1p(self.adata)

-    def peakmem_log1p(self, *_):
-        adata.uns.pop("log1p", None)
-        sc.pp.log1p(adata)
+    def peakmem_log1p(self, *_) -> None:
+        self.adata.uns.pop("log1p", None)
+        sc.pp.log1p(self.adata)

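Note on the pattern above: ASV's setup_cache hook runs only once per benchmark environment (rather than once per spawned process, as setup does), and the files it writes are available to the benchmarks that follow. That is why each suite now downloads once in setup_cache and merely reads a local .h5ad in setup. A minimal self-contained sketch of the same pattern; the expensive_load helper below is hypothetical and stands in for get_count_dataset / get_dataset from ._utils:

    from itertools import product

    import anndata as ad
    import numpy as np
    import scanpy as sc


    def expensive_load(dataset: str, layer: str) -> ad.AnnData:
        """Hypothetical stand-in for the repo's dataset helpers (e.g. a download)."""
        return ad.AnnData(np.random.poisson(1.0, size=(100, 50)).astype(np.float32))


    class ExampleSuite:
        params = (["pbmc3k"], ["counts"])
        param_names = ("dataset", "layer")

        def setup_cache(self) -> None:
            # Runs once; the .h5ad files it leaves behind are shared by every
            # process ASV spawns for this suite.
            for dataset, layer in product(*self.params):
                expensive_load(dataset, layer).write_h5ad(f"{dataset}_{layer}.h5ad")

        def setup(self, dataset: str, layer: str) -> None:
            # Runs in every process before each benchmark, but is now a cheap local read.
            self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")

        def time_filter_cells(self, *_) -> None:
            sc.pp.filter_cells(self.adata, min_genes=10)

Storing batch_key inside adata.uns (as PreprocessingCountsSuite does above) is what lets setup() get by with a single read_h5ad call instead of recomputing the second return value of get_count_dataset.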
benchmarks/benchmarks/preprocessing_log.py

Lines changed: 38 additions & 57 deletions
@@ -5,89 +5,70 @@

 from __future__ import annotations

+from itertools import product
 from typing import TYPE_CHECKING

+import anndata as ad
+
 import scanpy as sc
-from scanpy.preprocessing._utils import _get_mean_var

 from ._utils import get_dataset, param_skipper

 if TYPE_CHECKING:
-    from anndata import AnnData
-
     from ._utils import Dataset, KeyX

-# setup variables
-
-
-adata: AnnData
-batch_key: str | None
-
-
-def setup(dataset: Dataset, layer: KeyX, *_):
-    """Set up global variables before each benchmark."""
-    global adata, batch_key
-    adata, batch_key = get_dataset(dataset, layer=layer)
-

 # ASV suite

 params: tuple[list[Dataset], list[KeyX]] = (
     ["pbmc68k_reduced", "pbmc3k"],
     [None, "off-axis"],
 )
-param_names = ["dataset", "layer"]
-
+param_names = ("dataset", "layer")
 skip_when = param_skipper(param_names, params)


-def time_pca(*_):
-    sc.pp.pca(adata, svd_solver="arpack")
-
-
-def peakmem_pca(*_):
-    sc.pp.pca(adata, svd_solver="arpack")
-
-
-def time_highly_variable_genes(*_):
-    # the default flavor runs on log-transformed data
-    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
-
-
-def peakmem_highly_variable_genes(*_):
-    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
-
-
-# regress_out is very slow for this dataset
-@skip_when(dataset={"pbmc3k"})
-def time_regress_out(*_):
-    sc.pp.regress_out(adata, ["total_counts", "pct_counts_mt"])
-
+class PreprocessingSuite:  # noqa: D101
+    params = params
+    param_names = param_names

-@skip_when(dataset={"pbmc3k"})
-def peakmem_regress_out(*_):
-    sc.pp.regress_out(adata, ["total_counts", "pct_counts_mt"])
+    def setup_cache(self) -> None:
+        """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
+        for dataset, layer in product(*self.params):
+            adata, _ = get_dataset(dataset, layer=layer)
+            adata.write_h5ad(f"{dataset}_{layer}.h5ad")

+    def setup(self, dataset, layer) -> None:
+        self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")

-def time_scale(*_):
-    sc.pp.scale(adata, max_value=10)
+    def time_pca(self, *_) -> None:
+        sc.pp.pca(self.adata, svd_solver="arpack")

+    def peakmem_pca(self, *_) -> None:
+        sc.pp.pca(self.adata, svd_solver="arpack")

-def peakmem_scale(*_):
-    sc.pp.scale(adata, max_value=10)
+    def time_highly_variable_genes(self, *_) -> None:
+        # the default flavor runs on log-transformed data
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
+        )

+    def peakmem_highly_variable_genes(self, *_) -> None:
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
+        )

-class FastSuite:
-    """Suite for fast preprocessing operations."""
+    # regress_out is very slow for this dataset
+    @skip_when(dataset={"pbmc3k"})
+    def time_regress_out(self, *_) -> None:
+        sc.pp.regress_out(self.adata, ["total_counts", "pct_counts_mt"])

-    params: tuple[list[Dataset], list[KeyX]] = (
-        ["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k"],
-        [None, "off-axis"],
-    )
-    param_names = ("dataset", "layer")
+    @skip_when(dataset={"pbmc3k"})
+    def peakmem_regress_out(self, *_) -> None:
+        sc.pp.regress_out(self.adata, ["total_counts", "pct_counts_mt"])

-    def time_mean_var(self, *_):
-        _get_mean_var(adata.X)
+    def time_scale(self, *_) -> None:
+        sc.pp.scale(self.adata, max_value=10)

-    def peakmem_mean_var(self, *_):
-        _get_mean_var(adata.X)
+    def peakmem_scale(self, *_) -> None:
+        sc.pp.scale(self.adata, max_value=10)

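A side note on @skip_when: the param_skipper helper comes from benchmarks/benchmarks/_utils.py, which is not part of this diff, so its implementation is not shown here. A hedged sketch of how such a decorator could be built, assuming ASV's convention that a parameter combination raising NotImplementedError is reported as skipped rather than failed (the real helper may use a different mechanism):

    import functools


    def param_skipper(param_names, params):
        """Build a skip_when decorator bound to a benchmark parameter grid."""
        # params is unused in this sketch; it is kept for signature parity
        # with the call site in preprocessing_log.py.

        def skip_when(**skipped):
            def decorator(func):
                @functools.wraps(func)
                def wrapper(self, *args):
                    # ASV passes the current parameter combination positionally,
                    # in the same order as param_names.
                    bound = dict(zip(param_names, args))
                    if any(bound.get(name) in values for name, values in skipped.items()):
                        raise NotImplementedError  # skip this combination
                    return func(self, *args)

                return wrapper

            return decorator

        return skip_when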