Skip to content

Commit d7779a3

Browse files
committed
Don't use local functions to wrap numba functions
1 parent 9dd940e commit d7779a3

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

sgkit/stats/aggregation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def variant_stats(
680680
--------
681681
:func:`count_variant_genotypes`
682682
"""
683-
from .aggregation_numba_fns import count_hom
683+
from .aggregation_numba_fns import count_hom_new_axis
684684

685685
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
686686
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
@@ -697,7 +697,7 @@ def variant_stats(
697697
G = da.asarray(ds[call_genotype].data)
698698
H = xr.DataArray(
699699
da.map_blocks(
700-
lambda *args: count_hom(*args)[:, np.newaxis, :],
700+
count_hom_new_axis,
701701
G,
702702
np.zeros(3, np.uint64),
703703
drop_axis=2,
@@ -796,7 +796,7 @@ def sample_stats(
796796
ValueError
797797
If the dataset contains mixed-ploidy genotype calls.
798798
"""
799-
from .aggregation_numba_fns import count_hom
799+
from .aggregation_numba_fns import count_hom_new_axis
800800

801801
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
802802
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
@@ -805,7 +805,7 @@ def sample_stats(
805805
GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data)
806806
H = xr.DataArray(
807807
da.map_blocks(
808-
lambda *args: count_hom(*args)[:, np.newaxis, :],
808+
count_hom_new_axis,
809809
GT,
810810
np.zeros(3, np.uint64),
811811
drop_axis=2,

sgkit/stats/aggregation_numba_fns.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# in a separate file here, and imported dynamically to avoid
33
# initial compilation overhead.
44

5+
import numpy as np
6+
57
from sgkit.accelerate import numba_guvectorize, numba_jit
68
from sgkit.typing import ArrayLike
79

@@ -102,3 +104,7 @@ def count_hom(
102104
index = _classify_hom(genotypes[i])
103105
if index >= 0:
104106
out[index] += 1
107+
108+
109+
def count_hom_new_axis(genotypes: ArrayLike, _: ArrayLike) -> ArrayLike:
110+
return count_hom(genotypes, _)[:, np.newaxis, :]

sgkit/stats/popgen.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,7 @@ def pbs(
595595
cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3)) # type: ignore
596596
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None))
597597

598-
p = da.map_blocks(
599-
lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
600-
)
598+
p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64)
601599
assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)
602600

603601
new_ds = create_dataset(

0 commit comments

Comments
 (0)