44
55import numpy as np
66import pytest
7+ import sgkit as sg
78import tskit
9+ import xarray .testing as xt
810import zarr
911
10- from bio2zarr import tskit as ts
12+ from bio2zarr import tskit as tsk
13+ from bio2zarr import vcf
14+
15+
def simple_ts(add_individuals=False):
    """Build a small tree sequence with four samples and three sites.

    The topology is two independent cherries spanning the whole sequence
    (length 100): node 4 is the parent of samples 0 and 1, node 5 is the
    parent of samples 2 and 3. Three sites (positions 10, 20 and 30) each
    carry one mutation, with a mix of single- and multi-character allele
    strings.

    Args:
        add_individuals: When true, a fresh row is added to the individuals
            table for each sample node and the node is linked to it.
            Otherwise the sample nodes are given individual ``-1``.

    Returns:
        The tree sequence produced from the sorted table collection.
    """
    tables = tskit.TableCollection(sequence_length=100)

    # Four sample nodes at time 0, optionally one individual per node.
    for _ in range(4):
        individual = tables.individuals.add_row() if add_individuals else -1
        tables.nodes.add_row(
            flags=tskit.NODE_IS_SAMPLE, time=0, individual=individual
        )

    # Two internal nodes at time 1: MRCA for 0,1 and MRCA for 2,3.
    for _ in range(2):
        tables.nodes.add_row(flags=0, time=1)

    # (parent, child) pairs; every edge covers the full interval [0, 100).
    for parent, child in ((4, 0), (4, 1), (5, 2), (5, 3)):
        tables.edges.add_row(left=0, right=100, parent=parent, child=child)

    # (position, ancestral_state, mutated node, derived_state)
    site_specs = (
        (10, "A", 4, "TTTT"),
        (20, "CCC", 5, "G"),
        (30, "G", 0, "AA"),
    )
    for position, ancestral_state, node, derived_state in site_specs:
        site_id = tables.sites.add_row(
            position=position, ancestral_state=ancestral_state
        )
        tables.mutations.add_row(
            site=site_id, node=node, derived_state=derived_state
        )

    tables.sort()
    return tables.tree_sequence()
1138
1239
1340class TestTskit :
1441 def test_simple_tree_sequence (self , tmp_path ):
15- tables = tskit .TableCollection (sequence_length = 100 )
16- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
17- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
18- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
19- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
20- tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 0,1
21- tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 2,3
22- tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 0 )
23- tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 1 )
24- tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 2 )
25- tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 3 )
26- site_id = tables .sites .add_row (position = 10 , ancestral_state = "A" )
27- tables .mutations .add_row (site = site_id , node = 4 , derived_state = "TTTT" )
28- site_id = tables .sites .add_row (position = 20 , ancestral_state = "CCC" )
29- tables .mutations .add_row (site = site_id , node = 5 , derived_state = "G" )
30- site_id = tables .sites .add_row (position = 30 , ancestral_state = "G" )
31- tables .mutations .add_row (site = site_id , node = 0 , derived_state = "AA" )
32- tables .sort ()
33- tree_sequence = tables .tree_sequence ()
42+ tree_sequence = simple_ts ()
3443 tree_sequence .dump (tmp_path / "test.trees" )
3544
3645 # Manually specify the individuals_nodes, other tests use
37- # ts individuals.
46+ # tsk individuals.
3847 ind_nodes = np .array ([[0 , 1 ], [2 , 3 ]])
3948
4049 with tempfile .TemporaryDirectory () as tempdir :
4150 zarr_path = os .path .join (tempdir , "test_output.zarr" )
42- ts .convert (
51+ tsk .convert (
4352 tmp_path / "test.trees" ,
4453 zarr_path ,
4554 individuals_nodes = ind_nodes ,
@@ -59,7 +68,7 @@ def test_simple_tree_sequence(self, tmp_path):
5968 lengths = zroot ["variant_length" ][:]
6069 assert lengths .shape == (3 ,)
6170 assert lengths .dtype == np .int8
62- assert np .array_equal (lengths , [4 , 3 , 2 ])
71+ assert np .array_equal (lengths , [1 , 3 , 1 ])
6372
6473 genotypes = zroot ["call_genotype" ][:]
6574 assert genotypes .shape == (3 , 2 , 2 )
@@ -91,7 +100,7 @@ def test_simple_tree_sequence(self, tmp_path):
91100 region_index = zroot ["region_index" ][:]
92101 assert region_index .shape == (1 , 6 )
93102 assert region_index .dtype == np .int8
94- assert np .array_equal (region_index , [[0 , 0 , 10 , 30 , 31 , 3 ]])
103+ assert np .array_equal (region_index , [[0 , 0 , 10 , 30 , 30 , 3 ]])
95104
96105 assert set (zroot .array_keys ()) == {
97106 "variant_position" ,
@@ -112,7 +121,7 @@ def test_missing_dependency(self):
112121 side_effect = ImportError ("No module named 'tskit'" ),
113122 ):
114123 with pytest .raises (ImportError ) as exc_info :
115- ts .convert (
124+ tsk .convert (
116125 "UNUSED_PATH" ,
117126 "UNUSED_PATH" ,
118127 )
@@ -193,15 +202,15 @@ def test_position_dtype_selection(self, tmp_path):
193202 ts_large .dump (ts_path_large )
194203
195204 ind_nodes = np .array ([[0 ], [1 ]])
196- format_obj_small = ts .TskitFormat (ts_path_small , individuals_nodes = ind_nodes )
205+ format_obj_small = tsk .TskitFormat (ts_path_small , individuals_nodes = ind_nodes )
197206 schema_small = format_obj_small .generate_schema ()
198207
199208 position_field = next (
200209 f for f in schema_small .fields if f .name == "variant_position"
201210 )
202211 assert position_field .dtype == "i1"
203212
204- format_obj_large = ts .TskitFormat (ts_path_large , individuals_nodes = ind_nodes )
213+ format_obj_large = tsk .TskitFormat (ts_path_large , individuals_nodes = ind_nodes )
205214 schema_large = format_obj_large .generate_schema ()
206215
207216 position_field = next (
@@ -213,14 +222,14 @@ def test_initialization(self, simple_ts):
213222 ts_path , tree_sequence = simple_ts
214223
215224 # Test with default parameters
216- format_obj = ts .TskitFormat (ts_path )
225+ format_obj = tsk .TskitFormat (ts_path )
217226 assert format_obj .path == ts_path
218227 assert format_obj .ts .num_sites == tree_sequence .num_sites
219228 assert format_obj .contig_id == "1"
220229 assert not format_obj .isolated_as_missing
221230
222231 # Test with custom parameters
223- format_obj = ts .TskitFormat (
232+ format_obj = tsk .TskitFormat (
224233 ts_path ,
225234 sample_ids = ["ind1" , "ind2" ],
226235 contig_id = "chr1" ,
@@ -234,7 +243,7 @@ def test_initialization(self, simple_ts):
234243
235244 def test_basic_properties (self , simple_ts ):
236245 ts_path , _ = simple_ts
237- format_obj = ts .TskitFormat (ts_path )
246+ format_obj = tsk .TskitFormat (ts_path )
238247
239248 assert format_obj .num_records == format_obj .ts .num_sites
240249 assert format_obj .num_samples == 2 # Two individuals
@@ -251,7 +260,7 @@ def test_basic_properties(self, simple_ts):
251260 def test_custom_sample_ids (self , simple_ts ):
252261 ts_path , _ = simple_ts
253262 custom_ids = ["sample_X" , "sample_Y" ]
254- format_obj = ts .TskitFormat (ts_path , sample_ids = custom_ids )
263+ format_obj = tsk .TskitFormat (ts_path , sample_ids = custom_ids )
255264
256265 assert format_obj .num_samples == 2
257266 assert len (format_obj .samples ) == 2
@@ -262,11 +271,11 @@ def test_sample_id_length_mismatch(self, simple_ts):
262271 ts_path , _ = simple_ts
263272 # Wrong number of sample IDs
264273 with pytest .raises (ValueError , match = "Length of sample_ids.*does not match" ):
265- ts .TskitFormat (ts_path , sample_ids = ["only_one_id" ])
274+ tsk .TskitFormat (ts_path , sample_ids = ["only_one_id" ])
266275
267276 def test_schema_generation (self , simple_ts ):
268277 ts_path , _ = simple_ts
269- format_obj = ts .TskitFormat (ts_path )
278+ format_obj = tsk .TskitFormat (ts_path )
270279
271280 schema = format_obj .generate_schema ()
272281 assert schema .dimensions ["variants" ].size == 3
@@ -289,13 +298,13 @@ def test_schema_generation(self, simple_ts):
289298
290299 def test_iter_contig (self , simple_ts ):
291300 ts_path , _ = simple_ts
292- format_obj = ts .TskitFormat (ts_path )
301+ format_obj = tsk .TskitFormat (ts_path )
293302 contig_indices = list (format_obj .iter_contig (1 , 3 ))
294303 assert contig_indices == [0 , 0 ]
295304
296305 def test_iter_field (self , simple_ts ):
297306 ts_path , _ = simple_ts
298- format_obj = ts .TskitFormat (ts_path )
307+ format_obj = tsk .TskitFormat (ts_path )
299308 positions = list (format_obj .iter_field ("position" , None , 0 , 3 ))
300309 assert positions == [10 , 20 , 30 ]
301310 positions = list (format_obj .iter_field ("position" , None , 1 , 3 ))
@@ -341,7 +350,7 @@ def test_iter_field(self, simple_ts):
341350 def test_iter_alleles_and_genotypes (self , simple_ts , ind_nodes , expected_gts ):
342351 ts_path , _ = simple_ts
343352
344- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
353+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
345354
346355 shape = (2 , 2 ) # (num_samples, max_ploidy)
347356 results = list (format_obj .iter_alleles_and_genotypes (0 , 3 , shape , 2 ))
@@ -350,7 +359,7 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
350359
351360 for i , variant_data in enumerate (results ):
352361 if i == 0 :
353- assert variant_data .variant_length == 2
362+ assert variant_data .variant_length == 1
354363 assert np .array_equal (variant_data .alleles , ("A" , "TT" ))
355364 elif i == 1 :
356365 assert variant_data .variant_length == 3
@@ -371,7 +380,7 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
371380
372381 # Test with node ID that doesn't exist in tree sequence (out of range)
373382 invalid_nodes = np .array ([[10 , 11 ], [12 , 13 ]], dtype = np .int32 )
374- format_obj = ts .TskitFormat (ts_path , individuals_nodes = invalid_nodes )
383+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = invalid_nodes )
375384 shape = (2 , 2 )
376385 with pytest .raises (
377386 tskit .LibraryError , match = "out of bounds"
@@ -383,23 +392,23 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
383392 with pytest .raises (
384393 ValueError , match = "individuals_nodes must have at least one sample"
385394 ):
386- format_obj = ts .TskitFormat (ts_path , individuals_nodes = empty_nodes )
395+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = empty_nodes )
387396
388397 # Test with all invalid nodes (-1)
389398 all_invalid = np .full ((2 , 2 ), - 1 , dtype = np .int32 )
390399 with pytest .raises (
391400 ValueError , match = "individuals_nodes must have at least one valid sample"
392401 ):
393- format_obj = ts .TskitFormat (ts_path , individuals_nodes = all_invalid )
402+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = all_invalid )
394403
395404 def test_isolated_as_missing (self , tmp_path ):
396- def insert_branch_sites (ts , m = 1 ):
405+ def insert_branch_sites (tsk , m = 1 ):
397406 if m == 0 :
398- return ts
399- tables = ts .dump_tables ()
407+ return tsk
408+ tables = tsk .dump_tables ()
400409 tables .sites .clear ()
401410 tables .mutations .clear ()
402- for tree in ts .trees ():
411+ for tree in tsk .trees ():
403412 left , right = tree .interval
404413 delta = (right - left ) / (m * len (list (tree .nodes ())))
405414 x = left
@@ -422,7 +431,7 @@ def insert_branch_sites(ts, m=1):
422431 ts_path = tmp_path / "isolated_sample.trees"
423432 tree_sequence .dump (ts_path )
424433 ind_nodes = np .array ([[0 ], [1 ], [3 ]])
425- format_obj_default = ts .TskitFormat (
434+ format_obj_default = tsk .TskitFormat (
426435 ts_path , individuals_nodes = ind_nodes , isolated_as_missing = False
427436 )
428437 shape = (3 , 1 ) # (num_samples, max_ploidy)
@@ -438,7 +447,7 @@ def insert_branch_sites(ts, m=1):
438447 expected_gt_default = np .array ([[1 ], [0 ], [0 ]])
439448 assert np .array_equal (variant_data_default .genotypes , expected_gt_default )
440449
441- format_obj_missing = ts .TskitFormat (
450+ format_obj_missing = tsk .TskitFormat (
442451 ts_path , individuals_nodes = ind_nodes , isolated_as_missing = True
443452 )
444453 results_missing = list (
@@ -469,7 +478,7 @@ def test_genotype_dtype_selection(self, tmp_path):
469478 tree_sequence .dump (ts_path )
470479
471480 ind_nodes = np .array ([[0 , 1 ], [2 , 3 ]])
472- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
481+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
473482 schema = format_obj .generate_schema ()
474483 call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
475484 assert call_genotype_spec .dtype == "i1"
@@ -489,7 +498,36 @@ def test_genotype_dtype_selection(self, tmp_path):
489498 ts_path = tmp_path / "large_alleles.trees"
490499 tree_sequence .dump (ts_path )
491500
492- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
501+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
493502 schema = format_obj .generate_schema ()
494503 call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
495504 assert call_genotype_spec .dtype == "i4"
505+
506+
@pytest.mark.parametrize("ts", [simple_ts(add_individuals=True)])
def test_against_tskit_vcf_output(ts, tmp_path):
    """Compare direct tskit conversion with conversion via a VCF file.

    The tree sequence is dumped to disk and also written out as VCF. Both
    files are converted to Zarr, loaded with sgkit, and compared for
    equality after dropping the ``filters`` dimension and a fixed list of
    variables from the VCF-derived dataset.
    """
    trees_file = tmp_path / "ts.trees"
    vcf_file = tmp_path / "ts.vcf"
    ts.dump(trees_file)
    with open(vcf_file, "w") as vcf_out:
        ts.write_vcf(vcf_out)

    zarr_from_tskit = tmp_path / "tskit.zarr"
    zarr_from_vcf = tmp_path / "vcf.zarr"
    tsk.convert(trees_file, zarr_from_tskit)
    vcf.convert([vcf_file], zarr_from_vcf)

    # Variables removed from the VCF-derived dataset before comparing.
    dropped_vars = [
        "variant_id",
        "variant_id_mask",
        "variant_quality",
        "contig_length",
    ]
    tskit_ds = sg.load_dataset(zarr_from_tskit)
    vcf_ds = sg.load_dataset(zarr_from_vcf)
    vcf_ds = vcf_ds.drop_dims("filters").drop_vars(dropped_vars)
    xt.assert_equal(tskit_ds, vcf_ds)
0 commit comments