Skip to content

Commit 9f0a874

Browse files
timothymillartomwhite
authored andcommitted
Use gufunc in sample_stats
1 parent e07dde7 commit 9f0a874

File tree

2 files changed

+93
-78
lines changed

2 files changed

+93
-78
lines changed

sgkit/stats/aggregation.py

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -270,18 +270,6 @@ def count_cohort_alleles(
270270
return conditional_merge_datasets(ds, new_ds, merge)
271271

272272

273-
def _swap(dim: Dimension) -> Dimension:
274-
return "samples" if dim == "variants" else "variants"
275-
276-
277-
def call_rate(ds: Dataset, dim: Dimension, call_genotype_mask: Hashable) -> Dataset:
278-
odim = _swap(dim)[:-1]
279-
n_called = (~ds[call_genotype_mask].any(dim="ploidy")).sum(dim=dim)
280-
return create_dataset(
281-
{f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]}
282-
)
283-
284-
285273
def count_variant_genotypes(
286274
ds: Dataset,
287275
*,
@@ -465,39 +453,6 @@ def genotype_coords(
465453
return ds
466454

467455

468-
def count_genotypes(
469-
ds: Dataset,
470-
dim: Dimension,
471-
call_genotype: Hashable = variables.call_genotype,
472-
call_genotype_mask: Hashable = variables.call_genotype_mask,
473-
merge: bool = True,
474-
) -> Dataset:
475-
variables.validate(
476-
ds,
477-
{
478-
call_genotype_mask: variables.call_genotype_mask_spec,
479-
call_genotype: variables.call_genotype_spec,
480-
},
481-
)
482-
odim = _swap(dim)[:-1]
483-
M, G = ds[call_genotype_mask].any(dim="ploidy"), ds[call_genotype]
484-
n_hom_ref = (G == 0).all(dim="ploidy")
485-
n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy")
486-
n_non_ref = (G > 0).any(dim="ploidy")
487-
n_het = ~(n_hom_alt | n_hom_ref)
488-
# This would 0 out the `het` case with any missing calls
489-
agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call]
490-
new_ds = create_dataset(
491-
{
492-
f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call]
493-
f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call]
494-
f"{odim}_n_hom_alt": agg(n_hom_alt), # type: ignore[no-untyped-call]
495-
f"{odim}_n_non_ref": agg(n_non_ref), # type: ignore[no-untyped-call]
496-
}
497-
)
498-
return conditional_merge_datasets(ds, new_ds, merge)
499-
500-
501456
def call_allele_frequencies(
502457
ds: Dataset,
503458
*,
@@ -769,9 +724,7 @@ def variant_stats(
769724
def sample_stats(
770725
ds: Dataset,
771726
*,
772-
call_genotype_mask: Hashable = variables.call_genotype_mask,
773727
call_genotype: Hashable = variables.call_genotype,
774-
variant_allele_count: Hashable = variables.variant_allele_count,
775728
merge: bool = True,
776729
) -> Dataset:
777730
"""Compute quality control sample statistics from genotype calls.
@@ -784,15 +737,6 @@ def sample_stats(
784737
Input variable name holding call_genotype.
785738
Defined by :data:`sgkit.variables.call_genotype_spec`.
786739
Must be present in ``ds``.
787-
call_genotype_mask
788-
Input variable name holding call_genotype_mask.
789-
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
790-
Must be present in ``ds``.
791-
variant_allele_count
792-
Input variable name holding variant_allele_count,
793-
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
794-
If the variable is not present in ``ds``, it will be computed
795-
using :func:`count_variant_alleles`.
796740
merge
797741
If True (the default), merge the input dataset and the computed
798742
output variables into a single dataset, otherwise return only
@@ -815,25 +759,63 @@ def sample_stats(
815759
The number of variants with homozygous alternate calls.
816760
- :data:`sgkit.variables.sample_n_non_ref_spec` (samples):
817761
The number of variants that are not homozygous reference calls.
762+
763+
Note
764+
----
765+
If the dataset contains partial genotype calls (i.e., genotype calls with
766+
a mixture of called and missing alleles), these genotypes will be ignored
767+
when counting the number of homozygous, heterozygous or total genotype calls.
768+
769+
Note
770+
----
771+
When used on autopolyploid genotypes, this method treats genotypes calls
772+
with any level of heterozygosity as 'heterozygous'. Only fully homozygous
773+
genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'.
774+
775+
Warnings
776+
--------
777+
This method does not support mixed-ploidy datasets.
778+
779+
Raises
780+
------
781+
ValueError
782+
If the dataset contains mixed-ploidy genotype calls.
818783
"""
819-
variables.validate(
820-
ds,
821-
{
822-
call_genotype: variables.call_genotype_spec,
823-
call_genotype_mask: variables.call_genotype_mask_spec,
824-
},
784+
from .aggregation_numba_fns import count_hom
785+
786+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
787+
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
788+
if mixed_ploidy:
789+
raise ValueError("Mixed-ploidy dataset")
790+
G = da.array(ds[call_genotype].data)
791+
H = xr.DataArray(
792+
da.map_blocks(
793+
count_hom,
794+
G.transpose(1, 0, 2),
795+
np.zeros(3, np.uint64),
796+
drop_axis=(1, 2),
797+
new_axis=1,
798+
dtype=np.int64,
799+
chunks=(G.chunks[1], 3),
800+
),
801+
dims=["samples", "categories"],
825802
)
826-
new_ds = xr.merge(
827-
[
828-
call_rate(ds, dim="variants", call_genotype_mask=call_genotype_mask),
829-
count_genotypes(
830-
ds,
831-
dim="variants",
832-
call_genotype=call_genotype,
833-
call_genotype_mask=call_genotype_mask,
834-
merge=False,
835-
),
836-
]
803+
n_variant, _, _ = G.shape
804+
n_called = H.sum(axis=-1)
805+
call_rate = n_called / n_variant
806+
n_hom_ref = H[:, 0]
807+
n_hom_alt = H[:, 1]
808+
n_het = H[:, 2]
809+
n_non_ref = n_called - n_hom_ref
810+
new_ds = xr.Dataset(
811+
{
812+
variables.sample_n_called: n_called,
813+
variables.sample_call_rate: call_rate,
814+
variables.sample_n_het: n_het,
815+
variables.sample_n_hom_ref: n_hom_ref,
816+
variables.sample_n_hom_alt: n_hom_alt,
817+
variables.sample_n_non_ref: n_non_ref,
818+
}
837819
)
838820
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
839821

sgkit/tests/test_aggregation.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -794,15 +794,11 @@ def test_variant_stats__raise_on_mixed_ploidy():
794794
variant_stats(ds)
795795

796796

797-
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
798-
def test_sample_stats(precompute_variant_allele_count):
797+
def test_sample_stats():
799798
ds = get_dataset(
800799
[[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
801800
)
802-
if precompute_variant_allele_count:
803-
ds = count_variant_alleles(ds)
804801
ss = sample_stats(ds)
805-
806802
np.testing.assert_equal(ss["sample_n_called"], np.array([3, 3]))
807803
np.testing.assert_equal(ss["sample_call_rate"], np.array([0.75, 0.75]))
808804
np.testing.assert_equal(ss["sample_n_hom_ref"], np.array([0, 1]))
@@ -811,6 +807,43 @@ def test_sample_stats(precompute_variant_allele_count):
811807
np.testing.assert_equal(ss["sample_n_non_ref"], np.array([3, 2]))
812808

813809

810+
def test_sample_stats__multi_allelic():
811+
ds = simulate_genotype_call_dataset(n_variant=2, n_sample=4, n_allele=4, seed=0)
812+
ds["call_genotype"].data = [
813+
[[0, 0], [0, 0], [1, 1], [2, 2]],
814+
[[0, 0], [2, 3], [0, -1], [-1, 2]],
815+
]
816+
vs = sample_stats(ds)
817+
np.testing.assert_equal(vs["sample_n_called"], np.array([2, 2, 1, 1]))
818+
np.testing.assert_equal(vs["sample_call_rate"], np.array([1, 1, 0.5, 0.5]))
819+
np.testing.assert_equal(vs["sample_n_hom_ref"], np.array([2, 1, 0, 0]))
820+
np.testing.assert_equal(vs["sample_n_hom_alt"], np.array([0, 0, 1, 1]))
821+
np.testing.assert_equal(vs["sample_n_het"], np.array([0, 1, 0, 0]))
822+
np.testing.assert_equal(vs["sample_n_non_ref"], np.array([0, 1, 1, 1]))
823+
824+
825+
def test_sample_stats__tetraploid():
826+
ds = simulate_genotype_call_dataset(n_variant=2, n_sample=3, n_ploidy=4, seed=0)
827+
ds["call_genotype"].data = [
828+
[[0, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]],
829+
[[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, -1, 0]],
830+
]
831+
vs = sample_stats(ds)
832+
np.testing.assert_equal(vs["sample_n_called"], np.array([2, 2, 1]))
833+
np.testing.assert_equal(vs["sample_call_rate"], np.array([1, 1, 0.5]))
834+
np.testing.assert_equal(vs["sample_n_hom_ref"], np.array([1, 0, 0]))
835+
np.testing.assert_equal(vs["sample_n_hom_alt"], np.array([0, 0, 1]))
836+
np.testing.assert_equal(vs["sample_n_het"], np.array([1, 2, 0]))
837+
np.testing.assert_equal(vs["sample_n_non_ref"], np.array([1, 2, 1]))
838+
839+
840+
def test_sample_stats__raise_on_mixed_ploidy():
841+
ds = simulate_genotype_call_dataset(n_variant=2, n_sample=2, n_ploidy=3, seed=0)
842+
ds["call_genotype"].attrs["mixed_ploidy"] = True
843+
with pytest.raises(ValueError, match="Mixed-ploidy dataset"):
844+
sample_stats(ds)
845+
846+
814847
def test_infer_call_ploidy():
815848
ds = get_dataset(
816849
[

0 commit comments

Comments
 (0)