@@ -834,135 +834,118 @@ def test_ancestral_missing_info(tmp_path, caplog):
834834 assert inf_var .site .ancestral_state == var .site .ancestral_state
835835
836836
837- @pytest .mark .skip ("Depends on sgkit methods" )
838837class TestVariantDataErrors :
839838 @staticmethod
840- def simulate_genotype_call_dataset (* args , ** kwargs ):
841- # roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
842- # duplicate alleles are created. Doesn't need to be efficient: just for testing
843- if "seed" not in kwargs :
844- kwargs ["seed" ] = 123
845- ds = sgkit .simulate_genotype_call_dataset (* args , ** kwargs ) ## NOQA
846- variant_alleles = ds ["variant_allele" ].values
847- allowed_alleles = np .array (
848- ["A" , "T" , "C" , "G" , "N" ], dtype = variant_alleles .dtype
839+ def simulate_vcz (path , n_sample , n_variant , n_ploidy = 2 ):
840+ # use b2z to create a zarr dataset we can modify
841+ ts = msprime .sim_ancestry (
842+ n_sample , sequence_length = 1000 , ploidy = n_ploidy , random_seed = 123
849843 )
850- for row in range (len (variant_alleles )):
851- alleles = variant_alleles [row ]
852- if len (set (alleles )) != len (alleles ):
853- # Just use a set that we know is unique
854- variant_alleles [row ] = allowed_alleles [0 : len (alleles )]
855- ds ["variant_allele" ] = ds ["variant_allele" ].dims , variant_alleles
856- return ds
844+ # Add enough mutations that we get at least n_variant sites
845+ ts = msprime .sim_mutations (ts , rate = 1e-3 , random_seed = 123 )
846+ assert ts .num_sites >= n_variant
847+ # Filter down to the first n_variant sites
848+ ts = ts .keep_intervals ([(0 , ts .site (n_variant - 1 ).position + 1 )]).trim ()
849+ assert ts .num_sites == n_variant
850+ ts2z .convert (ts , path )
851+ z = zarr .open (path , mode = "a" )
852+ z .create_dataset (
853+ name = "variant_ancestral_allele" ,
854+ data = ts .sites_ancestral_state ,
855+ dtype = "S1" ,
856+ chunks = True ,
857+ )
858+ return z
859+
860+ def test_phased (self , tmp_path ):
861+ # Check the default runs without error
862+ path = tmp_path / "data.vcz"
863+ self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
864+ vdata = tsinfer .VariantData (path , "variant_ancestral_allele" )
865+ assert vdata .ploidy == 2
857866
858867 def test_bad_zarr_spec (self ):
859- ds = zarr .group ()
860- ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
868+ z = zarr .group ()
869+ z ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
861870 with pytest .raises (
862871 ValueError , match = "Expecting a VCF Zarr object with 3D call_genotype array"
863872 ):
864- tsinfer .VariantData (ds , np .zeros (10 , dtype = "<U1" ))
873+ tsinfer .VariantData (z , np .zeros (10 , dtype = "<U1" ))
865874
866875 def test_missing_phase (self , tmp_path ):
867- path = tmp_path / "data.zarr "
868- ds = self .simulate_genotype_call_dataset ( n_variant = 3 , n_sample = 3 )
869- tsutil . save_dataset ( ds , path )
876+ path = tmp_path / "data.vcz "
877+ z = self .simulate_vcz ( path , n_variant = 3 , n_sample = 3 )
878+ del z [ "call_genotype_phased" ]
870879 with pytest .raises (
871880 ValueError , match = "The call_genotype_phased array is missing"
872881 ):
873882 tsinfer .VariantData (path , "variant_ancestral_allele" )
874883
875- def test_phased (self , tmp_path ):
876- path = tmp_path / "data.zarr"
877- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
878- ds ["call_genotype_phased" ] = (
879- ds ["call_genotype" ].dims ,
880- np .ones (ds ["call_genotype" ].shape , dtype = bool ),
881- )
882- tsutil .save_dataset (ds , path )
883- tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
884-
885884 def test_ploidy1_missing_phase (self , tmp_path ):
886885 path = tmp_path / "data.zarr"
887886 # Ploidy==1 is always ok
888- ds = self .simulate_genotype_call_dataset ( n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
889- tsutil . save_dataset ( ds , path )
890- tsinfer .VariantData (path , ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) )
887+ z = self .simulate_vcz ( path , n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
888+ del z [ "call_genotype_phased" ]
889+ tsinfer .VariantData (path , "variant_ancestral_allele" )
891890
892891 def test_ploidy1_unphased (self , tmp_path ):
893892 path = tmp_path / "data.zarr"
894- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
895- ds ["call_genotype_phased" ] = (
896- ds ["call_genotype" ].dims ,
897- np .zeros (ds ["call_genotype" ].shape , dtype = bool ),
898- )
899- tsutil .save_dataset (ds , path )
900- tsinfer .VariantData (path , ds ["variant_allele" ][:, 0 ].values .astype (str ))
893+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
894+ z ["call_genotype_phased" ][:] = False
895+ tsinfer .VariantData (path , "variant_ancestral_allele" )
901896
902897 def test_duplicate_positions (self , tmp_path ):
903898 path = tmp_path / "data.zarr"
904- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
905- ds ["variant_position" ][2 ] = ds ["variant_position" ][1 ]
906- tsutil .save_dataset (ds , path )
899+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
900+ z ["variant_position" ][2 ] = z ["variant_position" ][1 ]
907901 with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
908902 tsinfer .VariantData (path , "variant_ancestral_allele" )
909903
910904 def test_bad_order_positions (self , tmp_path ):
911905 path = tmp_path / "data.zarr"
912- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
913- ds ["variant_position" ][0 ] = ds ["variant_position" ][2 ] - 0.5
914- tsutil .save_dataset (ds , path )
906+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
907+ z ["variant_position" ][0 ] = z ["variant_position" ][2 ] - 0.5
915908 with pytest .raises (ValueError , match = "duplicate or out-of-order values" ):
916909 tsinfer .VariantData (path , "variant_ancestral_allele" )
917910
918911 def test_bad_ancestral_state (self , tmp_path ):
919912 path = tmp_path / "data.zarr"
920- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
921- ancestral_state = ds ["variant_allele" ][:, 0 ].values .astype (str )
922- ancestral_state [1 ] = ""
923- tsutil .save_dataset (ds , path )
913+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
914+ z ["variant_ancestral_allele" ][1 ] = ""
924915 with pytest .raises (ValueError , match = "cannot contain empty strings" ):
925- tsinfer .VariantData (path , ancestral_state )
916+ tsinfer .VariantData (path , "variant_ancestral_allele" )
926917
927918 def test_ancestral_state_len_not_same_as_mask (self , tmp_path ):
928919 path = tmp_path / "data.zarr"
929- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
930- tsutil .save_dataset (ds , path )
931- ancestral_state = ds ["variant_allele" ][:, 0 ].values .astype (str )
932- site_mask = np .zeros (ds .sizes ["variants" ], dtype = bool )
920+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
921+ site_mask = np .zeros (z ["variant_ancestral_allele" ].shape [0 ], dtype = bool )
933922 site_mask [0 ] = True
934923 with pytest .raises (
935924 ValueError ,
936925 match = "Ancestral state array must be the same length as the number of"
937926 " selected sites" ,
938927 ):
939- tsinfer .VariantData (path , ancestral_state , site_mask = site_mask )
928+ tsinfer .VariantData (
929+ path , z ["variant_ancestral_allele" ][:], site_mask = site_mask
930+ )
940931
941932 def test_empty_alleles_not_at_end (self , tmp_path ):
942933 path = tmp_path / "data.zarr"
943- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
944- ds ["variant_allele" ] = (
945- ds ["variant_allele" ].dims ,
946- np .array ([["A" , "" , "C" ], ["A" , "C" , "" ], ["A" , "C" , "" ]], dtype = "S1" ),
947- )
948- tsutil .save_dataset (ds , path )
934+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
935+ z ["variant_allele" ][0 ] = ["" , "T" ]
949936 with pytest .raises (
950937 ValueError , match = 'Bad alleles: fill value "" in middle of list'
951938 ):
952- tsinfer .VariantData (path , ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) )
939+ tsinfer .VariantData (path , "variant_ancestral_allele" )
953940
954941 def test_unique_alleles (self , tmp_path ):
955942 path = tmp_path / "data.zarr"
956- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , n_ploidy = 1 )
957- ds ["variant_allele" ] = (
958- ds ["variant_allele" ].dims ,
959- np .array ([["A" , "C" , "T" ], ["A" , "C" , "" ], ["A" , "A" , "" ]], dtype = "S1" ),
960- )
961- tsutil .save_dataset (ds , path )
943+ z = self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
944+ z ["variant_allele" ][2 ] = ["T" , "T" ]
962945 with pytest .raises (
963946 ValueError , match = "Duplicate allele values provided at site 2"
964947 ):
965- tsinfer .VariantData (path , np . array ([ "A" , "A" , "A" ], dtype = "S1" ) )
948+ tsinfer .VariantData (path , "variant_ancestral_allele" )
966949
967950 def test_unimplemented_from_tree_sequence (self ):
968951 # NB we should reimplement something like this functionality.
@@ -972,28 +955,23 @@ def test_unimplemented_from_tree_sequence(self):
972955
973956 def test_all_masked (self , tmp_path ):
974957 path = tmp_path / "data.zarr"
975- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
976- tsutil .save_dataset (ds , path )
958+ self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
977959 with pytest .raises (ValueError , match = "All sites have been masked out" ):
978960 tsinfer .VariantData (
979- path , ds [ "variant_allele" ][:, 0 ]. astype ( str ) , site_mask = np .ones (3 , bool )
961+ path , "variant_ancestral_allele" , site_mask = np .ones (3 , bool )
980962 )
981963
982964 def test_missing_sites_time (self , tmp_path ):
983965 path = tmp_path / "data.zarr"
984- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 , phased = True )
985- tsutil .save_dataset (ds , path )
966+ self .simulate_vcz (path , n_variant = 3 , n_sample = 3 )
986967 with pytest .raises (
987968 ValueError , match = "The sites time array XX was not found in the dataset"
988969 ):
989- tsinfer .VariantData (
990- path , ds ["variant_allele" ][:, 0 ].astype (str ), sites_time = "XX"
991- )
970+ tsinfer .VariantData (path , "variant_ancestral_allele" , sites_time = "XX" )
992971
993972 def test_wrong_individuals_array_length (self , tmp_path ):
994973 path = tmp_path / "data.zarr"
995- ds = self .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 5 , phased = True )
996- tsutil .save_dataset (ds , path )
974+ self .simulate_vcz (path , n_variant = 3 , n_sample = 5 )
997975
998976 # Create arrays with wrong length (too short)
999977 wrong_length_time = np .arange (3 , dtype = np .float32 )
@@ -1007,7 +985,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
1007985 ):
1008986 tsinfer .VariantData (
1009987 path ,
1010- ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) ,
988+ "variant_ancestral_allele" ,
1011989 individuals_time = wrong_length_time ,
1012990 )
1013991
@@ -1016,7 +994,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
1016994 ):
1017995 tsinfer .VariantData (
1018996 path ,
1019- ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) ,
997+ "variant_ancestral_allele" ,
1020998 individuals_location = wrong_length_location ,
1021999 )
10221000
@@ -1025,7 +1003,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
10251003 ):
10261004 tsinfer .VariantData (
10271005 path ,
1028- ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) ,
1006+ "variant_ancestral_allele" ,
10291007 individuals_population = wrong_length_population ,
10301008 )
10311009
@@ -1034,7 +1012,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
10341012 ):
10351013 tsinfer .VariantData (
10361014 path ,
1037- ds [ "variant_allele" ][:, 0 ]. values . astype ( str ) ,
1015+ "variant_ancestral_allele" ,
10381016 individuals_flags = wrong_length_flags ,
10391017 )
10401018
0 commit comments