@@ -133,146 +133,141 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):
133133
134134
135135@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
136- def test_sgkit_dataset_accessors (tmp_path ):
137- ts , zarr_path = tsutil .make_ts_and_zarr (
138- tmp_path , add_optional = True , shuffle_alleles = False
139- )
140- samples = tsinfer .VariantData (
141- zarr_path , "variant_ancestral_allele" , sites_time = "sites_time"
142- )
143- ds = sgkit .load_dataset (zarr_path )
144-
145- assert samples .format_name == "tsinfer-variant-data"
146- assert samples .format_version == (0 , 1 )
147- assert samples .finalised
148- assert samples .sequence_length == ts .sequence_length + 1337
149- assert samples .num_sites == ts .num_sites
150- assert samples .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
151- assert samples .sites_metadata == [site .metadata for site in ts .sites ()]
152- assert np .array_equal (samples .sites_time , np .arange (ts .num_sites ) / ts .num_sites )
153- assert np .array_equal (samples .sites_position , ts .tables .sites .position )
154- for alleles , v in zip (samples .sites_alleles , ts .variants ()):
136+ @pytest .mark .parametrize ("in_mem" , [True , False ])
137+ def test_variantdata_accessors (tmp_path , in_mem ):
138+ path = None if in_mem else tmp_path
139+ ts , data = tsutil .make_ts_and_zarr (path , add_optional = True , shuffle_alleles = False )
140+ vd = tsinfer .VariantData (data , "variant_ancestral_allele" , sites_time = "sites_time" )
141+ ds = data if in_mem else sgkit .load_dataset (data )
142+
143+ assert vd .format_name == "tsinfer-variant-data"
144+ assert vd .format_version == (0 , 1 )
145+ assert vd .finalised
146+ assert vd .sequence_length == ts .sequence_length + 1337
147+ assert vd .num_sites == ts .num_sites
148+ assert vd .sites_metadata_schema == ts .tables .sites .metadata_schema .schema
149+ assert vd .sites_metadata == [site .metadata for site in ts .sites ()]
150+ assert np .array_equal (vd .sites_time , np .arange (ts .num_sites ) / ts .num_sites )
151+ assert np .array_equal (vd .sites_position , ts .tables .sites .position )
152+ for alleles , v in zip (vd .sites_alleles , ts .variants ()):
155153 # sgkit alleles are padded to be rectangular
156154 assert np .all (alleles [: len (v .alleles )] == v .alleles )
157155 assert np .all (alleles [len (v .alleles ) :] == "" )
158- assert np .array_equal (samples .sites_select , np .ones (ts .num_sites , dtype = bool ))
156+ assert np .array_equal (vd .sites_select , np .ones (ts .num_sites , dtype = bool ))
159157 assert np .array_equal (
160- samples .sites_ancestral_allele , np .zeros (ts .num_sites , dtype = np .int8 )
158+ vd .sites_ancestral_allele , np .zeros (ts .num_sites , dtype = np .int8 )
161159 )
162- assert np .array_equal (samples .sites_genotypes , ts .genotype_matrix ())
160+ assert np .array_equal (vd .sites_genotypes , ts .genotype_matrix ())
163161 assert np .array_equal (
164- samples .provenances_timestamp , ["2021-01-01T00:00:00" , "2021-01-02T00:00:00" ]
162+ vd .provenances_timestamp , ["2021-01-01T00:00:00" , "2021-01-02T00:00:00" ]
165163 )
166- assert samples .provenances_record == [{"foo" : 1 }, {"foo" : 2 }]
167- assert samples .num_samples == ts .num_samples
164+ assert vd .provenances_record == [{"foo" : 1 }, {"foo" : 2 }]
165+ assert vd .num_samples == ts .num_samples
168166 assert np .array_equal (
169- samples .samples_individual , np .repeat (np .arange (ts .num_samples // 3 ), 3 )
167+ vd .samples_individual , np .repeat (np .arange (ts .num_samples // 3 ), 3 )
170168 )
171- assert samples .metadata_schema == tsutil .example_schema ("example" ).schema
172- assert samples .metadata == ts .tables .metadata
169+ assert vd .metadata_schema == tsutil .example_schema ("example" ).schema
170+ assert vd .metadata == ts .tables .metadata
173171 assert (
174- samples .populations_metadata_schema
175- == ts .tables .populations .metadata_schema .schema
172+ vd .populations_metadata_schema == ts .tables .populations .metadata_schema .schema
176173 )
177- assert samples .populations_metadata == [pop .metadata for pop in ts .populations ()]
178- assert samples .num_individuals == ts .num_individuals
174+ assert vd .populations_metadata == [pop .metadata for pop in ts .populations ()]
175+ assert vd .num_individuals == ts .num_individuals
179176 assert np .array_equal (
180- samples .individuals_time , np .arange (ts .num_individuals , dtype = np .float32 )
177+ vd .individuals_time , np .arange (ts .num_individuals , dtype = np .float32 )
181178 )
182179 assert (
183- samples .individuals_metadata_schema
184- == ts .tables .individuals .metadata_schema .schema
180+ vd .individuals_metadata_schema == ts .tables .individuals .metadata_schema .schema
185181 )
186- assert samples .individuals_metadata == [
182+ assert vd .individuals_metadata == [
187183 {"variant_data_sample_id" : sample_id , ** ind .metadata }
188- for ind , sample_id in zip (ts .individuals (), ds [ " sample_id" ]. values )
184+ for ind , sample_id in zip (ts .individuals (), ds . sample_id [:] )
189185 ]
190186 assert np .array_equal (
191- samples .individuals_location ,
187+ vd .individuals_location ,
192188 np .tile (np .array ([["0" , "1" ]], dtype = "float32" ), (ts .num_individuals , 1 )),
193189 )
194190 assert np .array_equal (
195- samples .individuals_population , np .zeros (ts .num_individuals , dtype = "int32" )
191+ vd .individuals_population , np .zeros (ts .num_individuals , dtype = "int32" )
196192 )
197193 assert np .array_equal (
198- samples .individuals_flags ,
194+ vd .individuals_flags ,
199195 np .random .RandomState (42 ).randint (
200196 0 , 2_000_000 , ts .num_individuals , dtype = "int32"
201197 ),
202198 )
203199
204200 # Need to shuffle for the ancestral allele test
205- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path , add_optional = True )
206- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
201+ ts , data = tsutil .make_ts_and_zarr (path , add_optional = True )
202+ vd = tsinfer .VariantData (data , "variant_ancestral_allele" )
207203 for i in range (ts .num_sites ):
208204 assert (
209- samples .sites_alleles [i ][samples .sites_ancestral_allele [i ]]
205+ vd .sites_alleles [i ][vd .sites_ancestral_allele [i ]]
210206 == ts .site (i ).ancestral_state
211207 )
212208
213209
214210@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
215- def test_sgkit_accessors_defaults (tmp_path ):
216- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
217- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
218- ds = sgkit .load_dataset (zarr_path )
211+ @pytest .mark .parametrize ("in_mem" , [True , False ])
212+ def test_variantdata_accessors_defaults (tmp_path , in_mem ):
213+ path = None if in_mem else tmp_path
214+ ts , data = tsutil .make_ts_and_zarr (path )
215+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" )
216+ ds = data if in_mem else sgkit .load_dataset (data )
219217
220218 default_schema = tskit .MetadataSchema .permissive_json ().schema
221- assert samples .sequence_length == ts .sequence_length
222- assert samples .sites_metadata_schema == default_schema
223- assert samples .sites_metadata == [{} for _ in range (ts .num_sites )]
224- for time in samples .sites_time :
219+ assert vdata .sequence_length == ts .sequence_length
220+ assert vdata .sites_metadata_schema == default_schema
221+ assert vdata .sites_metadata == [{} for _ in range (ts .num_sites )]
222+ for time in vdata .sites_time :
225223 assert tskit .is_unknown_time (time )
226- assert np .array_equal (samples .sites_select , np .ones (ts .num_sites , dtype = bool ))
227- assert np .array_equal (samples .provenances_timestamp , [])
228- assert np .array_equal (samples .provenances_record , [])
229- assert samples .metadata_schema == default_schema
230- assert samples .metadata == {}
231- assert samples .populations_metadata_schema == default_schema
232- assert samples .populations_metadata == []
233- assert samples .individuals_metadata_schema == default_schema
234- assert samples .individuals_metadata == [
235- {"variant_data_sample_id" : sample_id } for sample_id in ds [ " sample_id" ]. values
224+ assert np .array_equal (vdata .sites_select , np .ones (ts .num_sites , dtype = bool ))
225+ assert np .array_equal (vdata .provenances_timestamp , [])
226+ assert np .array_equal (vdata .provenances_record , [])
227+ assert vdata .metadata_schema == default_schema
228+ assert vdata .metadata == {}
229+ assert vdata .populations_metadata_schema == default_schema
230+ assert vdata .populations_metadata == []
231+ assert vdata .individuals_metadata_schema == default_schema
232+ assert vdata .individuals_metadata == [
233+ {"variant_data_sample_id" : sample_id } for sample_id in ds . sample_id [:]
236234 ]
237- for time in samples .individuals_time :
235+ for time in vdata .individuals_time :
238236 assert tskit .is_unknown_time (time )
239237 assert np .array_equal (
240- samples .individuals_location , np .array ([[]] * ts .num_individuals , dtype = float )
238+ vdata .individuals_location , np .array ([[]] * ts .num_individuals , dtype = float )
241239 )
242240 assert np .array_equal (
243- samples .individuals_population , np .full (ts .num_individuals , tskit .NULL )
241+ vdata .individuals_population , np .full (ts .num_individuals , tskit .NULL )
244242 )
245243 assert np .array_equal (
246- samples .individuals_flags , np .zeros (ts .num_individuals , dtype = int )
244+ vdata .individuals_flags , np .zeros (ts .num_individuals , dtype = int )
247245 )
248246
249247
250248@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
251- def test_variantdata_sites_time_default (tmp_path ):
252- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
253- samples = tsinfer .VariantData (zarr_path , "variant_ancestral_allele" )
249+ def test_variantdata_sites_time_default ():
250+ ts , data = tsutil .make_ts_and_zarr ()
251+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" )
254252
255253 assert (
256- np .all (np .isnan (samples .sites_time ))
257- and samples .sites_time .size == samples .num_sites
254+ np .all (np .isnan (vdata .sites_time )) and vdata .sites_time .size == vdata .num_sites
258255 )
259256
260257
261258@pytest .mark .skipif (sys .platform == "win32" , reason = "No cyvcf2 on windows" )
262- def test_variantdata_sites_time_array (tmp_path ):
263- ts , zarr_path = tsutil .make_ts_and_zarr (tmp_path )
259+ def test_variantdata_sites_time_array ():
260+ ts , data = tsutil .make_ts_and_zarr ()
264261 sites_time = np .arange (ts .num_sites )
265- samples = tsinfer .VariantData (
266- zarr_path , "variant_ancestral_allele" , sites_time = sites_time
267- )
268- assert np .array_equal (samples .sites_time , sites_time )
262+ vdata = tsinfer .VariantData (data , "variant_ancestral_allele" , sites_time = sites_time )
263+ assert np .array_equal (vdata .sites_time , sites_time )
269264 wrong_length_sites_time = np .arange (ts .num_sites + 1 )
270265 with pytest .raises (
271266 ValueError ,
272267 match = "Sites time array must be the same length as the number of selected sites" ,
273268 ):
274269 tsinfer .VariantData (
275- zarr_path ,
270+ data ,
276271 "variant_ancestral_allele" ,
277272 sites_time = wrong_length_sites_time ,
278273 )
@@ -302,17 +297,17 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
302297 for i in sites :
303298 sites_mask [i ] = False
304299 tsutil .add_array_to_dataset ("variant_mask_42" , sites_mask , zarr_path )
305- samples = tsinfer .VariantData (
300+ vdata = tsinfer .VariantData (
306301 zarr_path ,
307302 "variant_ancestral_allele" ,
308303 site_mask = "variant_mask_42" ,
309304 )
310- assert samples .num_sites == len (sites )
311- assert np .array_equal (samples .sites_select , ~ sites_mask )
305+ assert vdata .num_sites == len (sites )
306+ assert np .array_equal (vdata .sites_select , ~ sites_mask )
312307 assert np .array_equal (
313- samples .sites_position , ts .tables .sites .position [~ sites_mask ]
308+ vdata .sites_position , ts .tables .sites .position [~ sites_mask ]
314309 )
315- inf_ts = tsinfer .infer (samples )
310+ inf_ts = tsinfer .infer (vdata )
316311 assert np .array_equal (
317312 ts .genotype_matrix ()[~ sites_mask ], inf_ts .genotype_matrix ()
318313 )
@@ -675,6 +670,14 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
675670
676671
677672class TestVariantDataErrors :
673+ def test_bad_zarr_spec (self ):
674+ ds = zarr .group ()
675+ ds ["call_genotype" ] = zarr .array (np .zeros (10 , dtype = np .int8 ))
676+ with pytest .raises (
677+ ValueError , match = "Expecting a VCF Zarr object with 3D call_genotype array"
678+ ):
679+ tsinfer .VariantData (ds , np .zeros (10 , dtype = "<U1" ))
680+
678681 def test_missing_phase (self , tmp_path ):
679682 path = tmp_path / "data.zarr"
680683 ds = sgkit .simulate_genotype_call_dataset (n_variant = 3 , n_sample = 3 )
0 commit comments