4
4
5
5
import numpy as np
6
6
import pytest
7
+ import sgkit as sg
7
8
import tskit
9
+ import xarray .testing as xt
8
10
import zarr
9
11
10
- from bio2zarr import tskit as ts
12
+ from bio2zarr import tskit as tsk
13
+ from bio2zarr import vcf
14
+
15
+
16
+ def simple_ts (add_individuals = False ):
17
+ tables = tskit .TableCollection (sequence_length = 100 )
18
+ for _ in range (4 ):
19
+ ind = - 1
20
+ if add_individuals :
21
+ ind = tables .individuals .add_row ()
22
+ tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 , individual = ind )
23
+ tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 0,1
24
+ tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 2,3
25
+ tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 0 )
26
+ tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 1 )
27
+ tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 2 )
28
+ tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 3 )
29
+ site_id = tables .sites .add_row (position = 10 , ancestral_state = "A" )
30
+ tables .mutations .add_row (site = site_id , node = 4 , derived_state = "TTTT" )
31
+ site_id = tables .sites .add_row (position = 20 , ancestral_state = "CCC" )
32
+ tables .mutations .add_row (site = site_id , node = 5 , derived_state = "G" )
33
+ site_id = tables .sites .add_row (position = 30 , ancestral_state = "G" )
34
+ tables .mutations .add_row (site = site_id , node = 0 , derived_state = "AA" )
35
+
36
+ tables .sort ()
37
+ return tables .tree_sequence ()
11
38
12
39
13
40
class TestTskit :
14
41
def test_simple_tree_sequence (self , tmp_path ):
15
- tables = tskit .TableCollection (sequence_length = 100 )
16
- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
17
- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
18
- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
19
- tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
20
- tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 0,1
21
- tables .nodes .add_row (flags = 0 , time = 1 ) # MRCA for 2,3
22
- tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 0 )
23
- tables .edges .add_row (left = 0 , right = 100 , parent = 4 , child = 1 )
24
- tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 2 )
25
- tables .edges .add_row (left = 0 , right = 100 , parent = 5 , child = 3 )
26
- site_id = tables .sites .add_row (position = 10 , ancestral_state = "A" )
27
- tables .mutations .add_row (site = site_id , node = 4 , derived_state = "TTTT" )
28
- site_id = tables .sites .add_row (position = 20 , ancestral_state = "CCC" )
29
- tables .mutations .add_row (site = site_id , node = 5 , derived_state = "G" )
30
- site_id = tables .sites .add_row (position = 30 , ancestral_state = "G" )
31
- tables .mutations .add_row (site = site_id , node = 0 , derived_state = "AA" )
32
- tables .sort ()
33
- tree_sequence = tables .tree_sequence ()
42
+ tree_sequence = simple_ts ()
34
43
tree_sequence .dump (tmp_path / "test.trees" )
35
44
36
45
# Manually specify the individuals_nodes, other tests use
37
- # ts individuals.
46
+ # tsk individuals.
38
47
ind_nodes = np .array ([[0 , 1 ], [2 , 3 ]])
39
48
40
49
with tempfile .TemporaryDirectory () as tempdir :
41
50
zarr_path = os .path .join (tempdir , "test_output.zarr" )
42
- ts .convert (
51
+ tsk .convert (
43
52
tmp_path / "test.trees" ,
44
53
zarr_path ,
45
54
individuals_nodes = ind_nodes ,
@@ -59,7 +68,7 @@ def test_simple_tree_sequence(self, tmp_path):
59
68
lengths = zroot ["variant_length" ][:]
60
69
assert lengths .shape == (3 ,)
61
70
assert lengths .dtype == np .int8
62
- assert np .array_equal (lengths , [4 , 3 , 2 ])
71
+ assert np .array_equal (lengths , [1 , 3 , 1 ])
63
72
64
73
genotypes = zroot ["call_genotype" ][:]
65
74
assert genotypes .shape == (3 , 2 , 2 )
@@ -91,7 +100,7 @@ def test_simple_tree_sequence(self, tmp_path):
91
100
region_index = zroot ["region_index" ][:]
92
101
assert region_index .shape == (1 , 6 )
93
102
assert region_index .dtype == np .int8
94
- assert np .array_equal (region_index , [[0 , 0 , 10 , 30 , 31 , 3 ]])
103
+ assert np .array_equal (region_index , [[0 , 0 , 10 , 30 , 30 , 3 ]])
95
104
96
105
assert set (zroot .array_keys ()) == {
97
106
"variant_position" ,
@@ -112,7 +121,7 @@ def test_missing_dependency(self):
112
121
side_effect = ImportError ("No module named 'tskit'" ),
113
122
):
114
123
with pytest .raises (ImportError ) as exc_info :
115
- ts .convert (
124
+ tsk .convert (
116
125
"UNUSED_PATH" ,
117
126
"UNUSED_PATH" ,
118
127
)
@@ -193,15 +202,15 @@ def test_position_dtype_selection(self, tmp_path):
193
202
ts_large .dump (ts_path_large )
194
203
195
204
ind_nodes = np .array ([[0 ], [1 ]])
196
- format_obj_small = ts .TskitFormat (ts_path_small , individuals_nodes = ind_nodes )
205
+ format_obj_small = tsk .TskitFormat (ts_path_small , individuals_nodes = ind_nodes )
197
206
schema_small = format_obj_small .generate_schema ()
198
207
199
208
position_field = next (
200
209
f for f in schema_small .fields if f .name == "variant_position"
201
210
)
202
211
assert position_field .dtype == "i1"
203
212
204
- format_obj_large = ts .TskitFormat (ts_path_large , individuals_nodes = ind_nodes )
213
+ format_obj_large = tsk .TskitFormat (ts_path_large , individuals_nodes = ind_nodes )
205
214
schema_large = format_obj_large .generate_schema ()
206
215
207
216
position_field = next (
@@ -213,14 +222,14 @@ def test_initialization(self, simple_ts):
213
222
ts_path , tree_sequence = simple_ts
214
223
215
224
# Test with default parameters
216
- format_obj = ts .TskitFormat (ts_path )
225
+ format_obj = tsk .TskitFormat (ts_path )
217
226
assert format_obj .path == ts_path
218
227
assert format_obj .ts .num_sites == tree_sequence .num_sites
219
228
assert format_obj .contig_id == "1"
220
229
assert not format_obj .isolated_as_missing
221
230
222
231
# Test with custom parameters
223
- format_obj = ts .TskitFormat (
232
+ format_obj = tsk .TskitFormat (
224
233
ts_path ,
225
234
sample_ids = ["ind1" , "ind2" ],
226
235
contig_id = "chr1" ,
@@ -234,7 +243,7 @@ def test_initialization(self, simple_ts):
234
243
235
244
def test_basic_properties (self , simple_ts ):
236
245
ts_path , _ = simple_ts
237
- format_obj = ts .TskitFormat (ts_path )
246
+ format_obj = tsk .TskitFormat (ts_path )
238
247
239
248
assert format_obj .num_records == format_obj .ts .num_sites
240
249
assert format_obj .num_samples == 2 # Two individuals
@@ -251,7 +260,7 @@ def test_basic_properties(self, simple_ts):
251
260
def test_custom_sample_ids (self , simple_ts ):
252
261
ts_path , _ = simple_ts
253
262
custom_ids = ["sample_X" , "sample_Y" ]
254
- format_obj = ts .TskitFormat (ts_path , sample_ids = custom_ids )
263
+ format_obj = tsk .TskitFormat (ts_path , sample_ids = custom_ids )
255
264
256
265
assert format_obj .num_samples == 2
257
266
assert len (format_obj .samples ) == 2
@@ -262,11 +271,11 @@ def test_sample_id_length_mismatch(self, simple_ts):
262
271
ts_path , _ = simple_ts
263
272
# Wrong number of sample IDs
264
273
with pytest .raises (ValueError , match = "Length of sample_ids.*does not match" ):
265
- ts .TskitFormat (ts_path , sample_ids = ["only_one_id" ])
274
+ tsk .TskitFormat (ts_path , sample_ids = ["only_one_id" ])
266
275
267
276
def test_schema_generation (self , simple_ts ):
268
277
ts_path , _ = simple_ts
269
- format_obj = ts .TskitFormat (ts_path )
278
+ format_obj = tsk .TskitFormat (ts_path )
270
279
271
280
schema = format_obj .generate_schema ()
272
281
assert schema .dimensions ["variants" ].size == 3
@@ -289,13 +298,13 @@ def test_schema_generation(self, simple_ts):
289
298
290
299
def test_iter_contig (self , simple_ts ):
291
300
ts_path , _ = simple_ts
292
- format_obj = ts .TskitFormat (ts_path )
301
+ format_obj = tsk .TskitFormat (ts_path )
293
302
contig_indices = list (format_obj .iter_contig (1 , 3 ))
294
303
assert contig_indices == [0 , 0 ]
295
304
296
305
def test_iter_field (self , simple_ts ):
297
306
ts_path , _ = simple_ts
298
- format_obj = ts .TskitFormat (ts_path )
307
+ format_obj = tsk .TskitFormat (ts_path )
299
308
positions = list (format_obj .iter_field ("position" , None , 0 , 3 ))
300
309
assert positions == [10 , 20 , 30 ]
301
310
positions = list (format_obj .iter_field ("position" , None , 1 , 3 ))
@@ -341,7 +350,7 @@ def test_iter_field(self, simple_ts):
341
350
def test_iter_alleles_and_genotypes (self , simple_ts , ind_nodes , expected_gts ):
342
351
ts_path , _ = simple_ts
343
352
344
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
353
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
345
354
346
355
shape = (2 , 2 ) # (num_samples, max_ploidy)
347
356
results = list (format_obj .iter_alleles_and_genotypes (0 , 3 , shape , 2 ))
@@ -350,7 +359,7 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
350
359
351
360
for i , variant_data in enumerate (results ):
352
361
if i == 0 :
353
- assert variant_data .variant_length == 2
362
+ assert variant_data .variant_length == 1
354
363
assert np .array_equal (variant_data .alleles , ("A" , "TT" ))
355
364
elif i == 1 :
356
365
assert variant_data .variant_length == 3
@@ -371,7 +380,7 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
371
380
372
381
# Test with node ID that doesn't exist in tree sequence (out of range)
373
382
invalid_nodes = np .array ([[10 , 11 ], [12 , 13 ]], dtype = np .int32 )
374
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = invalid_nodes )
383
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = invalid_nodes )
375
384
shape = (2 , 2 )
376
385
with pytest .raises (
377
386
tskit .LibraryError , match = "out of bounds"
@@ -383,23 +392,23 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
383
392
with pytest .raises (
384
393
ValueError , match = "individuals_nodes must have at least one sample"
385
394
):
386
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = empty_nodes )
395
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = empty_nodes )
387
396
388
397
# Test with all invalid nodes (-1)
389
398
all_invalid = np .full ((2 , 2 ), - 1 , dtype = np .int32 )
390
399
with pytest .raises (
391
400
ValueError , match = "individuals_nodes must have at least one valid sample"
392
401
):
393
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = all_invalid )
402
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = all_invalid )
394
403
395
404
def test_isolated_as_missing (self , tmp_path ):
396
- def insert_branch_sites (ts , m = 1 ):
405
+ def insert_branch_sites (tsk , m = 1 ):
397
406
if m == 0 :
398
- return ts
399
- tables = ts .dump_tables ()
407
+ return tsk
408
+ tables = tsk .dump_tables ()
400
409
tables .sites .clear ()
401
410
tables .mutations .clear ()
402
- for tree in ts .trees ():
411
+ for tree in tsk .trees ():
403
412
left , right = tree .interval
404
413
delta = (right - left ) / (m * len (list (tree .nodes ())))
405
414
x = left
@@ -422,7 +431,7 @@ def insert_branch_sites(ts, m=1):
422
431
ts_path = tmp_path / "isolated_sample.trees"
423
432
tree_sequence .dump (ts_path )
424
433
ind_nodes = np .array ([[0 ], [1 ], [3 ]])
425
- format_obj_default = ts .TskitFormat (
434
+ format_obj_default = tsk .TskitFormat (
426
435
ts_path , individuals_nodes = ind_nodes , isolated_as_missing = False
427
436
)
428
437
shape = (3 , 1 ) # (num_samples, max_ploidy)
@@ -438,7 +447,7 @@ def insert_branch_sites(ts, m=1):
438
447
expected_gt_default = np .array ([[1 ], [0 ], [0 ]])
439
448
assert np .array_equal (variant_data_default .genotypes , expected_gt_default )
440
449
441
- format_obj_missing = ts .TskitFormat (
450
+ format_obj_missing = tsk .TskitFormat (
442
451
ts_path , individuals_nodes = ind_nodes , isolated_as_missing = True
443
452
)
444
453
results_missing = list (
@@ -469,7 +478,7 @@ def test_genotype_dtype_selection(self, tmp_path):
469
478
tree_sequence .dump (ts_path )
470
479
471
480
ind_nodes = np .array ([[0 , 1 ], [2 , 3 ]])
472
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
481
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
473
482
schema = format_obj .generate_schema ()
474
483
call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
475
484
assert call_genotype_spec .dtype == "i1"
@@ -489,7 +498,36 @@ def test_genotype_dtype_selection(self, tmp_path):
489
498
ts_path = tmp_path / "large_alleles.trees"
490
499
tree_sequence .dump (ts_path )
491
500
492
- format_obj = ts .TskitFormat (ts_path , individuals_nodes = ind_nodes )
501
+ format_obj = tsk .TskitFormat (ts_path , individuals_nodes = ind_nodes )
493
502
schema = format_obj .generate_schema ()
494
503
call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
495
504
assert call_genotype_spec .dtype == "i4"
505
+
506
+
507
+ @pytest .mark .parametrize (
508
+ "ts" ,
509
+ [
510
+ simple_ts (add_individuals = True ),
511
+ ],
512
+ )
513
+ def test_against_tskit_vcf_output (ts , tmp_path ):
514
+ vcf_path = tmp_path / "ts.vcf"
515
+ ts_path = tmp_path / "ts.trees"
516
+ ts .dump (ts_path )
517
+ with open (vcf_path , "w" ) as f :
518
+ ts .write_vcf (f )
519
+
520
+ tskit_zarr = tmp_path / "tskit.zarr"
521
+ vcf_zarr = tmp_path / "vcf.zarr"
522
+ tsk .convert (ts_path , tskit_zarr )
523
+
524
+ vcf .convert ([vcf_path ], vcf_zarr )
525
+ ds1 = sg .load_dataset (tskit_zarr )
526
+ ds2 = (
527
+ sg .load_dataset (vcf_zarr )
528
+ .drop_dims ("filters" )
529
+ .drop_vars (
530
+ ["variant_id" , "variant_id_mask" , "variant_quality" , "contig_length" ]
531
+ )
532
+ )
533
+ xt .assert_equal (ds1 , ds2 )
0 commit comments