1
- from typing import Any , Dict , Hashable
1
+ from typing import Hashable
2
2
3
3
import dask .array as da
4
4
import numpy as np
@@ -96,7 +96,9 @@ def count_call_alleles(
96
96
def count_variant_alleles (
97
97
ds : Dataset ,
98
98
* ,
99
+ call_genotype : Hashable = variables .call_genotype ,
99
100
call_allele_count : Hashable = variables .call_allele_count ,
101
+ using : Literal [variables .call_allele_count , variables .call_genotype ] = variables .call_allele_count , # type: ignore
100
102
merge : bool = True ,
101
103
) -> Dataset :
102
104
"""Compute allele count from per-sample allele counts, or genotype calls.
@@ -105,11 +107,22 @@ def count_variant_alleles(
105
107
----------
106
108
ds
107
109
Dataset containing genotype calls.
110
+ call_genotype
111
+ Input variable name holding call_genotype as defined by
112
+ :data:`sgkit.variables.call_genotype_spec`.
113
+ This variable is only used if specified by the 'using' argument.
108
114
call_allele_count
109
115
Input variable name holding call_allele_count as defined by
110
116
:data:`sgkit.variables.call_allele_count_spec`.
117
+ This variable is only used if specified by the 'using' argument.
111
118
If the variable is not present in ``ds``, it will be computed
112
119
using :func:`count_call_alleles`.
120
+ using
121
+ Specify the variable from which allele counts are calculated.
122
+ If ``'call_allele_count'`` (the default), the result will
123
+ be calculated from the call_allele_count variable.
124
+ If ``'call_genotype'``, the result will be calculated from
125
+ the call_genotype variable.
113
126
merge
114
127
If True (the default), merge the input dataset and the computed
115
128
output variables into a single dataset, otherwise return only
@@ -122,6 +135,12 @@ def count_variant_alleles(
122
135
of allele counts with shape (variants, alleles) and values corresponding to
123
136
the number of non-missing occurrences of each allele.
124
137
138
+ Note
139
+ ----
140
+ This method is more efficient when calculating allele counts directly from
141
+ the call_genotype variable unless the call_allele_count variable has already
142
+ been (or will be) calculated.
143
+
125
144
Examples
126
145
--------
127
146
@@ -141,14 +160,28 @@ def count_variant_alleles(
141
160
[2, 2],
142
161
[4, 0]], dtype=uint64)
143
162
"""
144
- ds = define_variable_if_absent (
145
- ds , variables .call_allele_count , call_allele_count , count_call_alleles
146
- )
147
- variables .validate (ds , {call_allele_count : variables .call_allele_count_spec })
148
-
149
- new_ds = create_dataset (
150
- {variables .variant_allele_count : ds [call_allele_count ].sum (dim = "samples" )}
151
- )
163
+ if using == variables .call_allele_count :
164
+ ds = define_variable_if_absent (
165
+ ds , variables .call_allele_count , call_allele_count , count_call_alleles
166
+ )
167
+ variables .validate (ds , {call_allele_count : variables .call_allele_count_spec })
168
+ AC = ds [call_allele_count ].sum (dim = "samples" )
169
+ elif using == variables .call_genotype :
170
+ from .aggregation_numba_fns import count_alleles
171
+
172
+ variables .validate (ds , {call_genotype : variables .call_genotype_spec })
173
+ n_alleles = ds .dims ["alleles" ]
174
+ n_variant = ds .dims ["variants" ]
175
+ G = da .asarray (ds [call_genotype ]).reshape ((n_variant , - 1 ))
176
+ shape = (G .chunks [0 ], n_alleles )
177
+ # use uint64 dummy array to return uint64 counts array
178
+ N = np .empty (n_alleles , dtype = np .uint64 )
179
+ AC = da .map_blocks (count_alleles , G , N , chunks = shape , drop_axis = 1 , new_axis = 1 )
180
+ AC = xr .DataArray (AC , dims = ["variants" , "alleles" ])
181
+ else :
182
+ options = {variables .call_genotype , variables .call_allele_count }
183
+ raise ValueError (f"The 'using' argument must be one of { options } ." )
184
+ new_ds = create_dataset ({variables .variant_allele_count : AC })
152
185
return conditional_merge_datasets (ds , new_ds , merge )
153
186
154
187
@@ -601,35 +634,9 @@ def cohort_allele_frequencies(
601
634
return conditional_merge_datasets (ds , new_ds , merge )
602
635
603
636
604
- def allele_frequency (
605
- ds : Dataset ,
606
- call_genotype_mask : Hashable ,
607
- variant_allele_count : Hashable ,
608
- ) -> Dataset :
609
- data_vars : Dict [Hashable , Any ] = {}
610
- # only compute variant allele count if not already in dataset
611
- if variant_allele_count in ds :
612
- variables .validate (
613
- ds , {variant_allele_count : variables .variant_allele_count_spec }
614
- )
615
- AC = ds [variant_allele_count ]
616
- else :
617
- AC = count_variant_alleles (ds , merge = False )[variables .variant_allele_count ]
618
- data_vars [variables .variant_allele_count ] = AC
619
-
620
- M = ds [call_genotype_mask ].stack (calls = ("samples" , "ploidy" ))
621
- AN = (~ M ).sum (dim = "calls" )
622
- assert AN .shape == (ds .dims ["variants" ],)
623
-
624
- data_vars [variables .variant_allele_total ] = AN
625
- data_vars [variables .variant_allele_frequency ] = AC / AN
626
- return create_dataset (data_vars )
627
-
628
-
629
637
def variant_stats (
630
638
ds : Dataset ,
631
639
* ,
632
- call_genotype_mask : Hashable = variables .call_genotype_mask ,
633
640
call_genotype : Hashable = variables .call_genotype ,
634
641
variant_allele_count : Hashable = variables .variant_allele_count ,
635
642
merge : bool = True ,
@@ -644,10 +651,6 @@ def variant_stats(
644
651
Input variable name holding call_genotype.
645
652
Defined by :data:`sgkit.variables.call_genotype_spec`.
646
653
Must be present in ``ds``.
647
- call_genotype_mask
648
- Input variable name holding call_genotype_mask.
649
- Defined by :data:`sgkit.variables.call_genotype_mask_spec`
650
- Must be present in ``ds``.
651
654
variant_allele_count
652
655
Input variable name holding variant_allele_count,
653
656
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
@@ -681,31 +684,85 @@ def variant_stats(
681
684
The number of occurrences of all alleles.
682
685
- :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles):
683
686
The frequency of occurrence of each allele.
687
+
688
+ Note
689
+ ----
690
+ If the dataset contains partial genotype calls (i.e., genotype calls with
691
+ a mixture of called and missing alleles), these genotypes will be ignored
692
+ when counting the number of homozygous, heterozygous or total genotype calls.
693
+ However, the called alleles will be counted when calculating allele counts
694
+ and frequencies using :func:`count_variant_alleles`.
695
+
696
+ Note
697
+ ----
698
+ When used on autopolyploid genotypes, this method treats genotype calls
699
+ with any level of heterozygosity as 'heterozygous'. Only fully homozygous
700
+ genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'.
701
+
702
+ Warnings
703
+ --------
704
+ This method does not support mixed-ploidy datasets.
705
+
706
+ Raises
707
+ ------
708
+ ValueError
709
+ If the dataset contains mixed-ploidy genotype calls.
710
+
711
+ See Also
712
+ --------
713
+ :func:`count_variant_genotypes`
684
714
"""
685
- variables .validate (
715
+ from .aggregation_numba_fns import count_hom
716
+
717
+ variables .validate (ds , {call_genotype : variables .call_genotype_spec })
718
+ mixed_ploidy = ds [call_genotype ].attrs .get ("mixed_ploidy" , False )
719
+ if mixed_ploidy :
720
+ raise ValueError ("Mixed-ploidy dataset" )
721
+ AC = define_variable_if_absent (
686
722
ds ,
687
- {
688
- call_genotype : variables .call_genotype_spec ,
689
- call_genotype_mask : variables .call_genotype_mask_spec ,
690
- },
723
+ variables .variant_allele_count ,
724
+ variant_allele_count ,
725
+ count_variant_alleles ,
726
+ using = variables .call_genotype , # improved performance
727
+ merge = False ,
728
+ )[variant_allele_count ]
729
+ G = da .array (ds [call_genotype ].data )
730
+ H = xr .DataArray (
731
+ da .map_blocks (
732
+ count_hom ,
733
+ G ,
734
+ np .zeros (3 , np .uint64 ),
735
+ drop_axis = (1 , 2 ),
736
+ new_axis = 1 ,
737
+ dtype = np .int64 ,
738
+ chunks = (G .chunks [0 ], 3 ),
739
+ ),
740
+ dims = ["variants" , "categories" ],
691
741
)
692
- new_ds = xr .merge (
693
- [
694
- call_rate (ds , dim = "samples" , call_genotype_mask = call_genotype_mask ),
695
- count_genotypes (
696
- ds ,
697
- dim = "samples" ,
698
- call_genotype = call_genotype ,
699
- call_genotype_mask = call_genotype_mask ,
700
- merge = False ,
701
- ),
702
- allele_frequency (
703
- ds ,
704
- call_genotype_mask = call_genotype_mask ,
705
- variant_allele_count = variant_allele_count ,
706
- ),
707
- ]
742
+ _ , n_sample , _ = G .shape
743
+ n_called = H .sum (axis = - 1 )
744
+ call_rate = n_called / n_sample
745
+ n_hom_ref = H [:, 0 ]
746
+ n_hom_alt = H [:, 1 ]
747
+ n_het = H [:, 2 ]
748
+ n_non_ref = n_called - n_hom_ref
749
+ allele_total = AC .sum (axis = - 1 ).astype (int ) # backwards compatibility
750
+ new_ds = xr .Dataset (
751
+ {
752
+ variables .variant_n_called : n_called ,
753
+ variables .variant_call_rate : call_rate ,
754
+ variables .variant_n_het : n_het ,
755
+ variables .variant_n_hom_ref : n_hom_ref ,
756
+ variables .variant_n_hom_alt : n_hom_alt ,
757
+ variables .variant_n_non_ref : n_non_ref ,
758
+ variables .variant_allele_count : AC ,
759
+ variables .variant_allele_total : allele_total ,
760
+ variables .variant_allele_frequency : AC / allele_total ,
761
+ }
708
762
)
763
+ # for backwards compatible behavior
764
+ if (variant_allele_count in ds ) and merge :
765
+ new_ds = new_ds .drop_vars (variant_allele_count )
709
766
return conditional_merge_datasets (ds , variables .validate (new_ds ), merge )
710
767
711
768
0 commit comments