@@ -270,18 +270,6 @@ def count_cohort_alleles(
270270 return conditional_merge_datasets (ds , new_ds , merge )
271271
272272
273- def _swap (dim : Dimension ) -> Dimension :
274- return "samples" if dim == "variants" else "variants"
275-
276-
277- def call_rate (ds : Dataset , dim : Dimension , call_genotype_mask : Hashable ) -> Dataset :
278- odim = _swap (dim )[:- 1 ]
279- n_called = (~ ds [call_genotype_mask ].any (dim = "ploidy" )).sum (dim = dim )
280- return create_dataset (
281- {f"{ odim } _n_called" : n_called , f"{ odim } _call_rate" : n_called / ds .dims [dim ]}
282- )
283-
284-
285273def count_variant_genotypes (
286274 ds : Dataset ,
287275 * ,
@@ -465,39 +453,6 @@ def genotype_coords(
465453 return ds
466454
467455
468- def count_genotypes (
469- ds : Dataset ,
470- dim : Dimension ,
471- call_genotype : Hashable = variables .call_genotype ,
472- call_genotype_mask : Hashable = variables .call_genotype_mask ,
473- merge : bool = True ,
474- ) -> Dataset :
475- variables .validate (
476- ds ,
477- {
478- call_genotype_mask : variables .call_genotype_mask_spec ,
479- call_genotype : variables .call_genotype_spec ,
480- },
481- )
482- odim = _swap (dim )[:- 1 ]
483- M , G = ds [call_genotype_mask ].any (dim = "ploidy" ), ds [call_genotype ]
484- n_hom_ref = (G == 0 ).all (dim = "ploidy" )
485- n_hom_alt = ((G > 0 ) & (G [..., 0 ] == G )).all (dim = "ploidy" )
486- n_non_ref = (G > 0 ).any (dim = "ploidy" )
487- n_het = ~ (n_hom_alt | n_hom_ref )
488- # This would 0 out the `het` case with any missing calls
489- agg = lambda x : xr .where (M , False , x ).sum (dim = dim ) # type: ignore[no-untyped-call]
490- new_ds = create_dataset (
491- {
492- f"{ odim } _n_het" : agg (n_het ), # type: ignore[no-untyped-call]
493- f"{ odim } _n_hom_ref" : agg (n_hom_ref ), # type: ignore[no-untyped-call]
494- f"{ odim } _n_hom_alt" : agg (n_hom_alt ), # type: ignore[no-untyped-call]
495- f"{ odim } _n_non_ref" : agg (n_non_ref ), # type: ignore[no-untyped-call]
496- }
497- )
498- return conditional_merge_datasets (ds , new_ds , merge )
499-
500-
501456def call_allele_frequencies (
502457 ds : Dataset ,
503458 * ,
@@ -769,9 +724,7 @@ def variant_stats(
769724def sample_stats (
770725 ds : Dataset ,
771726 * ,
772- call_genotype_mask : Hashable = variables .call_genotype_mask ,
773727 call_genotype : Hashable = variables .call_genotype ,
774- variant_allele_count : Hashable = variables .variant_allele_count ,
775728 merge : bool = True ,
776729) -> Dataset :
777730 """Compute quality control sample statistics from genotype calls.
@@ -784,15 +737,6 @@ def sample_stats(
784737 Input variable name holding call_genotype.
785738 Defined by :data:`sgkit.variables.call_genotype_spec`.
786739 Must be present in ``ds``.
787- call_genotype_mask
788- Input variable name holding call_genotype_mask.
789- Defined by :data:`sgkit.variables.call_genotype_mask_spec`
790- Must be present in ``ds``.
791- variant_allele_count
792- Input variable name holding variant_allele_count,
793- as defined by :data:`sgkit.variables.variant_allele_count_spec`.
794- If the variable is not present in ``ds``, it will be computed
795- using :func:`count_variant_alleles`.
796740 merge
797741 If True (the default), merge the input dataset and the computed
798742 output variables into a single dataset, otherwise return only
@@ -815,25 +759,63 @@ def sample_stats(
815759 The number of variants with homozygous alternate calls.
816760 - :data:`sgkit.variables.sample_n_non_ref_spec` (samples):
817761 The number of variants that are not homozygous reference calls.
762+
763+ Note
764+ ----
765+ If the dataset contains partial genotype calls (i.e., genotype calls with
766+ a mixture of called and missing alleles), these genotypes will be ignored
767+ when counting the number of homozygous, heterozygous or total genotype calls.
768+
769+ Note
770+ ----
771+ When used on autopolyploid genotypes, this method treats genotype calls
772+ with any level of heterozygosity as 'heterozygous'. Only fully homozygous
773+ genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'.
774+
775+ Warnings
776+ --------
777+ This method does not support mixed-ploidy datasets.
778+
779+ Raises
780+ ------
781+ ValueError
782+ If the dataset contains mixed-ploidy genotype calls.
818783 """
819- variables .validate (
820- ds ,
821- {
822- call_genotype : variables .call_genotype_spec ,
823- call_genotype_mask : variables .call_genotype_mask_spec ,
824- },
784+ from .aggregation_numba_fns import count_hom
785+
786+ variables .validate (ds , {call_genotype : variables .call_genotype_spec })
787+ mixed_ploidy = ds [call_genotype ].attrs .get ("mixed_ploidy" , False )
788+ if mixed_ploidy :
789+ raise ValueError ("Mixed-ploidy dataset" )
790+ G = da .array (ds [call_genotype ].data )
791+ H = xr .DataArray (
792+ da .map_blocks (
793+ count_hom ,
794+ G .transpose (1 , 0 , 2 ),
795+ np .zeros (3 , np .uint64 ),
796+ drop_axis = (1 , 2 ),
797+ new_axis = 1 ,
798+ dtype = np .int64 ,
799+ chunks = (G .chunks [1 ], 3 ),
800+ ),
801+ dims = ["samples" , "categories" ],
825802 )
826- new_ds = xr .merge (
827- [
828- call_rate (ds , dim = "variants" , call_genotype_mask = call_genotype_mask ),
829- count_genotypes (
830- ds ,
831- dim = "variants" ,
832- call_genotype = call_genotype ,
833- call_genotype_mask = call_genotype_mask ,
834- merge = False ,
835- ),
836- ]
803+ n_variant , _ , _ = G .shape
804+ n_called = H .sum (axis = - 1 )
805+ call_rate = n_called / n_variant
806+ n_hom_ref = H [:, 0 ]
807+ n_hom_alt = H [:, 1 ]
808+ n_het = H [:, 2 ]
809+ n_non_ref = n_called - n_hom_ref
810+ new_ds = xr .Dataset (
811+ {
812+ variables .sample_n_called : n_called ,
813+ variables .sample_call_rate : call_rate ,
814+ variables .sample_n_het : n_het ,
815+ variables .sample_n_hom_ref : n_hom_ref ,
816+ variables .sample_n_hom_alt : n_hom_alt ,
817+ variables .sample_n_non_ref : n_non_ref ,
818+ }
837819 )
838820 return conditional_merge_datasets (ds , variables .validate (new_ds ), merge )
839821
0 commit comments