Skip to content

Commit 364552a

Browse files
timothymillartomwhite
authored andcommitted
Improve performance of variant_stats #1116
* Add count_variant_alleles option to calculate directly from calls * Improve performance of variant_stats using gufuncs * Raise error if variant_stats used on mixed-ploidy data * Document behavior of variant_stats with partial genotype calls
1 parent 72ba776 commit 364552a

File tree

3 files changed

+308
-81
lines changed

3 files changed

+308
-81
lines changed

sgkit/stats/aggregation.py

Lines changed: 117 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Hashable
1+
from typing import Hashable
22

33
import dask.array as da
44
import numpy as np
@@ -96,7 +96,9 @@ def count_call_alleles(
9696
def count_variant_alleles(
9797
ds: Dataset,
9898
*,
99+
call_genotype: Hashable = variables.call_genotype,
99100
call_allele_count: Hashable = variables.call_allele_count,
101+
using: Literal[variables.call_allele_count, variables.call_genotype] = variables.call_allele_count, # type: ignore
100102
merge: bool = True,
101103
) -> Dataset:
102104
"""Compute allele count from per-sample allele counts, or genotype calls.
@@ -105,11 +107,22 @@ def count_variant_alleles(
105107
----------
106108
ds
107109
Dataset containing genotype calls.
110+
call_genotype
111+
Input variable name holding call_genotype as defined by
112+
:data:`sgkit.variables.call_genotype_spec`.
113+
This variable is only used if specified by the 'using' argument.
108114
call_allele_count
109115
Input variable name holding call_allele_count as defined by
110116
:data:`sgkit.variables.call_allele_count_spec`.
117+
This variable is only used if specified by the 'using' argument.
111118
If the variable is not present in ``ds``, it will be computed
112119
using :func:`count_call_alleles`.
120+
using
121+
Specifies the variable from which allele counts are calculated.
122+
If ``'call_allele_count'`` (the default), the result will
123+
be calculated from the call_allele_count variable.
124+
If ``'call_genotype'``, the result will be calculated from
125+
the call_genotype variable.
113126
merge
114127
If True (the default), merge the input dataset and the computed
115128
output variables into a single dataset, otherwise return only
@@ -122,6 +135,12 @@ def count_variant_alleles(
122135
of allele counts with shape (variants, alleles) and values corresponding to
123136
the number of non-missing occurrences of each allele.
124137
138+
Note
139+
----
140+
This method is more efficient when calculating allele counts directly from
141+
the call_genotype variable unless the call_allele_count variable has already
142+
been (or will be) calculated.
143+
125144
Examples
126145
--------
127146
@@ -141,14 +160,28 @@ def count_variant_alleles(
141160
[2, 2],
142161
[4, 0]], dtype=uint64)
143162
"""
144-
ds = define_variable_if_absent(
145-
ds, variables.call_allele_count, call_allele_count, count_call_alleles
146-
)
147-
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
148-
149-
new_ds = create_dataset(
150-
{variables.variant_allele_count: ds[call_allele_count].sum(dim="samples")}
151-
)
163+
if using == variables.call_allele_count:
164+
ds = define_variable_if_absent(
165+
ds, variables.call_allele_count, call_allele_count, count_call_alleles
166+
)
167+
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
168+
AC = ds[call_allele_count].sum(dim="samples")
169+
elif using == variables.call_genotype:
170+
from .aggregation_numba_fns import count_alleles
171+
172+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
173+
n_alleles = ds.dims["alleles"]
174+
n_variant = ds.dims["variants"]
175+
G = da.asarray(ds[call_genotype]).reshape((n_variant, -1))
176+
shape = (G.chunks[0], n_alleles)
177+
# use uint64 dummy array to return uint64 counts array
178+
N = np.empty(n_alleles, dtype=np.uint64)
179+
AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1)
180+
AC = xr.DataArray(AC, dims=["variants", "alleles"])
181+
else:
182+
options = {variables.call_genotype, variables.call_allele_count}
183+
raise ValueError(f"The 'using' argument must be one of {options}.")
184+
new_ds = create_dataset({variables.variant_allele_count: AC})
152185
return conditional_merge_datasets(ds, new_ds, merge)
153186

154187

@@ -601,35 +634,9 @@ def cohort_allele_frequencies(
601634
return conditional_merge_datasets(ds, new_ds, merge)
602635

603636

604-
def allele_frequency(
605-
ds: Dataset,
606-
call_genotype_mask: Hashable,
607-
variant_allele_count: Hashable,
608-
) -> Dataset:
609-
data_vars: Dict[Hashable, Any] = {}
610-
# only compute variant allele count if not already in dataset
611-
if variant_allele_count in ds:
612-
variables.validate(
613-
ds, {variant_allele_count: variables.variant_allele_count_spec}
614-
)
615-
AC = ds[variant_allele_count]
616-
else:
617-
AC = count_variant_alleles(ds, merge=False)[variables.variant_allele_count]
618-
data_vars[variables.variant_allele_count] = AC
619-
620-
M = ds[call_genotype_mask].stack(calls=("samples", "ploidy"))
621-
AN = (~M).sum(dim="calls")
622-
assert AN.shape == (ds.dims["variants"],)
623-
624-
data_vars[variables.variant_allele_total] = AN
625-
data_vars[variables.variant_allele_frequency] = AC / AN
626-
return create_dataset(data_vars)
627-
628-
629637
def variant_stats(
630638
ds: Dataset,
631639
*,
632-
call_genotype_mask: Hashable = variables.call_genotype_mask,
633640
call_genotype: Hashable = variables.call_genotype,
634641
variant_allele_count: Hashable = variables.variant_allele_count,
635642
merge: bool = True,
@@ -644,10 +651,6 @@ def variant_stats(
644651
Input variable name holding call_genotype.
645652
Defined by :data:`sgkit.variables.call_genotype_spec`.
646653
Must be present in ``ds``.
647-
call_genotype_mask
648-
Input variable name holding call_genotype_mask.
649-
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
650-
Must be present in ``ds``.
651654
variant_allele_count
652655
Input variable name holding variant_allele_count,
653656
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
@@ -681,31 +684,85 @@ def variant_stats(
681684
The number of occurrences of all alleles.
682685
- :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles):
683686
The frequency of occurrence of each allele.
687+
688+
Note
689+
----
690+
If the dataset contains partial genotype calls (i.e., genotype calls with
691+
a mixture of called and missing alleles), these genotypes will be ignored
692+
when counting the number of homozygous, heterozygous or total genotype calls.
693+
However, the called alleles will be counted when calculating allele counts
694+
and frequencies using :func:`count_variant_alleles`.
695+
696+
Note
697+
----
698+
When used on autopolyploid genotypes, this method treats genotype calls
699+
with any level of heterozygosity as 'heterozygous'. Only fully homozygous
700+
genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'.
701+
702+
Warnings
703+
--------
704+
This method does not support mixed-ploidy datasets.
705+
706+
Raises
707+
------
708+
ValueError
709+
If the dataset contains mixed-ploidy genotype calls.
710+
711+
See Also
712+
--------
713+
:func:`count_variant_genotypes`
684714
"""
685-
variables.validate(
715+
from .aggregation_numba_fns import count_hom
716+
717+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
718+
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
719+
if mixed_ploidy:
720+
raise ValueError("Mixed-ploidy dataset")
721+
AC = define_variable_if_absent(
686722
ds,
687-
{
688-
call_genotype: variables.call_genotype_spec,
689-
call_genotype_mask: variables.call_genotype_mask_spec,
690-
},
723+
variables.variant_allele_count,
724+
variant_allele_count,
725+
count_variant_alleles,
726+
using=variables.call_genotype, # improved performance
727+
merge=False,
728+
)[variant_allele_count]
729+
G = da.array(ds[call_genotype].data)
730+
H = xr.DataArray(
731+
da.map_blocks(
732+
count_hom,
733+
G,
734+
np.zeros(3, np.uint64),
735+
drop_axis=(1, 2),
736+
new_axis=1,
737+
dtype=np.int64,
738+
chunks=(G.chunks[0], 3),
739+
),
740+
dims=["variants", "categories"],
691741
)
692-
new_ds = xr.merge(
693-
[
694-
call_rate(ds, dim="samples", call_genotype_mask=call_genotype_mask),
695-
count_genotypes(
696-
ds,
697-
dim="samples",
698-
call_genotype=call_genotype,
699-
call_genotype_mask=call_genotype_mask,
700-
merge=False,
701-
),
702-
allele_frequency(
703-
ds,
704-
call_genotype_mask=call_genotype_mask,
705-
variant_allele_count=variant_allele_count,
706-
),
707-
]
742+
_, n_sample, _ = G.shape
743+
n_called = H.sum(axis=-1)
744+
call_rate = n_called / n_sample
745+
n_hom_ref = H[:, 0]
746+
n_hom_alt = H[:, 1]
747+
n_het = H[:, 2]
748+
n_non_ref = n_called - n_hom_ref
749+
allele_total = AC.sum(axis=-1).astype(int) # backwards compatibility
750+
new_ds = xr.Dataset(
751+
{
752+
variables.variant_n_called: n_called,
753+
variables.variant_call_rate: call_rate,
754+
variables.variant_n_het: n_het,
755+
variables.variant_n_hom_ref: n_hom_ref,
756+
variables.variant_n_hom_alt: n_hom_alt,
757+
variables.variant_n_non_ref: n_non_ref,
758+
variables.variant_allele_count: AC,
759+
variables.variant_allele_total: allele_total,
760+
variables.variant_allele_frequency: AC / allele_total,
761+
}
708762
)
763+
# for backwards compatible behavior
764+
if (variant_allele_count in ds) and merge:
765+
new_ds = new_ds.drop_vars(variant_allele_count)
709766
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
710767

711768

sgkit/stats/aggregation_numba_fns.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in a separate file here, and imported dynamically to avoid
33
# initial compilation overhead.
44

5-
from sgkit.accelerate import numba_guvectorize
5+
from sgkit.accelerate import numba_guvectorize, numba_jit
66
from sgkit.typing import ArrayLike
77

88

@@ -12,6 +12,10 @@
1212
"void(int16[:], uint8[:], uint8[:])",
1313
"void(int32[:], uint8[:], uint8[:])",
1414
"void(int64[:], uint8[:], uint8[:])",
15+
"void(int8[:], uint64[:], uint64[:])",
16+
"void(int16[:], uint64[:], uint64[:])",
17+
"void(int32[:], uint64[:], uint64[:])",
18+
"void(int64[:], uint64[:], uint64[:])",
1519
],
1620
"(k),(n)->(n)",
1721
)
@@ -26,9 +30,10 @@ def count_alleles(
2630
Genotype call of shape (ploidy,) containing alleles encoded as
2731
type `int` with values < 0 indicating a missing allele.
2832
_
29-
Dummy variable of type `uint8` and shape (alleles,) used to
30-
define the number of unique alleles to be counted in the
31-
return value.
33+
Dummy variable of type `uint8` or `uint64` and shape (alleles,)
34+
used to define the number of unique alleles to be counted in the
35+
return value. The dtype of this array determines the dtype of the
36+
returned array.
3237
3338
Returns
3439
-------
@@ -43,3 +48,57 @@ def count_alleles(
4348
a = g[i]
4449
if a >= 0:
4550
out[a] += 1
51+
52+
53+
@numba_jit(nogil=True)
54+
def _classify_hom(genotype: ArrayLike) -> int: # pragma: no cover
55+
a0 = genotype[0]
56+
cat = min(a0, 1) # -1, 0, 1
57+
for i in range(1, len(genotype)):
58+
if cat < 0:
59+
break
60+
a = genotype[i]
61+
if a != a0:
62+
cat = 2 # het
63+
if a < 0:
64+
cat = -1
65+
return cat
66+
67+
68+
@numba_guvectorize(  # type: ignore
    [
        "void(int8[:,:], uint64[:], int64[:])",
        "void(int16[:,:], uint64[:], int64[:])",
        "void(int32[:,:], uint64[:], int64[:])",
        "void(int64[:,:], uint64[:], int64[:])",
    ],
    "(n, k),(c)->(c)",
)
def count_hom(
    genotypes: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Generalized U-function for counting homozygous and heterozygous genotypes.

    Parameters
    ----------
    genotypes
        Genotype calls of shape (samples, ploidy) containing alleles encoded
        as type `int` with values < 0 indicating a missing allele.
    _
        Dummy variable of type `uint64` with length 3 which determines the
        number of categories returned (this should always be 3).

    Note
    ----
    This method is not suitable for mixed-ploidy genotypes.

    Returns
    -------
    counts : ndarray
        Counts of homozygous reference, homozygous alternate, and heterozygous
        genotypes. Genotype calls containing any missing allele are not counted.
    """
    out[:] = 0
    n_genotypes = len(genotypes)
    for row in range(n_genotypes):
        category = _classify_hom(genotypes[row])
        if category < 0:
            # Calls with a missing allele fall into no category.
            continue
        out[category] += 1

0 commit comments

Comments
 (0)