@@ -29,6 +29,12 @@ def test_missing_dependency():
29
29
)
30
30
31
31
32
+ def tskit_model_mapping (ind_nodes , ind_names = None ):
33
+ if ind_names is None :
34
+ ind_names = ["tsk{j}" for j in range (len (ind_nodes ))]
35
+ return tskit .VcfModelMapping (ind_nodes , ind_names )
36
+
37
+
32
38
def add_mutations (ts ):
33
39
# Add some mutation to the tree sequence. This guarantees that
34
40
# we have variation at all sites > 0.
@@ -88,15 +94,6 @@ def insert_branch_sites(ts, m=1):
88
94
return tables .tree_sequence ()
89
95
90
96
91
- @pytest .fixture ()
92
- def fx_ts_isolated_samples ():
93
- tables = tskit .Tree .generate_balanced (2 , span = 10 ).tree_sequence .dump_tables ()
94
- # This also tests sample nodes that are not a single block at
95
- # the start of the nodes table.
96
- tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
97
- return insert_branch_sites (tables .tree_sequence ())
98
-
99
-
100
97
class TestSimpleTs :
101
98
@pytest .fixture ()
102
99
def conversion (self , tmp_path ):
@@ -193,17 +190,28 @@ class TestTskitFormat:
193
190
"""Unit tests for TskitFormat without using full conversion."""
194
191
195
192
@pytest .fixture ()
196
- def fx_simple_ts (self , tmp_path ):
193
+ def fx_simple_ts (self ):
197
194
return simple_ts (add_individuals = True )
198
195
199
196
@pytest .fixture ()
200
- def fx_ts_2_diploids (self , tmp_path ):
197
+ def fx_ts_2_diploids (self ):
201
198
ts = msprime .sim_ancestry (2 , sequence_length = 10 , random_seed = 42 )
202
199
return add_mutations (ts )
203
200
204
201
@pytest .fixture ()
205
- def fx_no_individuals_ts (self , tmp_path ):
206
- return simple_ts (add_individuals = False )
202
+ def fx_ts_isolated_samples (self ):
203
+ tables = tskit .Tree .generate_balanced (2 , span = 10 ).tree_sequence .dump_tables ()
204
+ # This also tests sample nodes that are not a single block at
205
+ # the start of the nodes table.
206
+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
207
+ return insert_branch_sites (tables .tree_sequence ())
208
+
209
+ def test_path_or_ts_input (self , tmp_path , fx_simple_ts ):
210
+ f1 = tsk .TskitFormat (fx_simple_ts )
211
+ ts_path = tmp_path / "trees.ts"
212
+ fx_simple_ts .dump (ts_path )
213
+ f2 = tsk .TskitFormat (ts_path )
214
+ f1 .ts .tables .assert_equals (f2 .ts .tables )
207
215
208
216
def test_small_position_dtype (self ):
209
217
tables = tskit .TableCollection (sequence_length = 100 )
@@ -311,6 +319,23 @@ def test_iter_field(self, fx_simple_ts):
311
319
with pytest .raises (ValueError , match = "Unknown field" ):
312
320
list (format_obj .iter_field ("unknown_field" , None , 0 , 3 ))
313
321
322
+ def test_zero_samples (self , fx_simple_ts ):
323
+ model_mapping = tskit_model_mapping (np .array ([]))
324
+ with pytest .raises (ValueError , match = "at least one sample" ):
325
+ tsk .TskitFormat (fx_simple_ts , model_mapping = model_mapping )
326
+
327
+ def test_no_valid_samples (self , fx_simple_ts ):
328
+ model_mapping = fx_simple_ts .map_to_vcf_model ()
329
+ model_mapping .individuals_nodes [:] = - 1
330
+ with pytest .raises (ValueError , match = "at least one valid sample" ):
331
+ tsk .TskitFormat (fx_simple_ts , model_mapping = model_mapping )
332
+
333
+ def test_model_size_mismatch (self , fx_simple_ts ):
334
+ model_mapping = fx_simple_ts .map_to_vcf_model ()
335
+ model_mapping .individuals_name = ["x" ]
336
+ with pytest .raises (ValueError , match = "match number of samples" ):
337
+ tsk .TskitFormat (fx_simple_ts , model_mapping = model_mapping )
338
+
314
339
@pytest .mark .parametrize (
315
340
("ind_nodes" , "expected_gts" ),
316
341
[
@@ -347,10 +372,7 @@ def test_iter_field(self, fx_simple_ts):
347
372
],
348
373
)
349
374
def test_iter_alleles_and_genotypes (self , fx_simple_ts , ind_nodes , expected_gts ):
350
- model_mapping = tskit .VcfModelMapping (
351
- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
352
- )
353
-
375
+ model_mapping = tskit_model_mapping (ind_nodes )
354
376
format_obj = tsk .TskitFormat (fx_simple_ts , model_mapping = model_mapping )
355
377
356
378
shape = (2 , 2 ) # (num_samples, max_ploidy)
@@ -375,9 +397,7 @@ def test_iter_alleles_and_genotypes(self, fx_simple_ts, ind_nodes, expected_gts)
375
397
def test_iter_alleles_and_genotypes_missing_node (self , fx_ts_2_diploids ):
376
398
# Test with node ID that doesn't exist in tree sequence (out of range)
377
399
ind_nodes = np .array ([[10 , 11 ], [12 , 13 ]], dtype = np .int32 )
378
- model_mapping = tskit .VcfModelMapping (
379
- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
380
- )
400
+ model_mapping = tskit_model_mapping (ind_nodes )
381
401
format_obj = tsk .TskitFormat (fx_ts_2_diploids , model_mapping = model_mapping )
382
402
shape = (2 , 2 )
383
403
with pytest .raises (
@@ -387,9 +407,7 @@ def test_iter_alleles_and_genotypes_missing_node(self, fx_ts_2_diploids):
387
407
388
408
def test_isolated_as_missing (self , fx_ts_isolated_samples ):
389
409
ind_nodes = np .array ([[0 ], [1 ], [3 ]])
390
- model_mapping = tskit .VcfModelMapping (
391
- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
392
- )
410
+ model_mapping = tskit_model_mapping (ind_nodes )
393
411
394
412
format_obj_default = tsk .TskitFormat (
395
413
fx_ts_isolated_samples ,
@@ -427,7 +445,7 @@ def test_isolated_as_missing(self, fx_ts_isolated_samples):
427
445
expected_gt_missing = np .array ([[1 ], [0 ], [- 1 ]])
428
446
nt .assert_array_equal (variant_data_missing .genotypes , expected_gt_missing )
429
447
430
- def test_genotype_dtype_i1 (self , tmp_path ):
448
+ def test_genotype_dtype_i1 (self ):
431
449
tables = tskit .TableCollection (sequence_length = 100 )
432
450
for _ in range (4 ):
433
451
tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
@@ -438,15 +456,13 @@ def test_genotype_dtype_i1(self, tmp_path):
438
456
tables .mutations .add_row (site = site_id , node = 0 , derived_state = "T" )
439
457
tables .sort ()
440
458
tree_sequence = tables .tree_sequence ()
441
- ts_path = tmp_path / "small_alleles.trees"
442
- tree_sequence .dump (ts_path )
443
459
444
- format_obj = tsk .TskitFormat (ts_path )
460
+ format_obj = tsk .TskitFormat (tree_sequence )
445
461
schema = format_obj .generate_schema ()
446
462
call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
447
463
assert call_genotype_spec .dtype == "i1"
448
464
449
- def test_genotype_dtype_i4 (self , tmp_path ):
465
+ def test_genotype_dtype_i4 (self ):
450
466
tables = tskit .TableCollection (sequence_length = 100 )
451
467
for _ in range (4 ):
452
468
tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
@@ -459,10 +475,8 @@ def test_genotype_dtype_i4(self, tmp_path):
459
475
460
476
tables .sort ()
461
477
tree_sequence = tables .tree_sequence ()
462
- ts_path = tmp_path / "large_alleles.trees"
463
- tree_sequence .dump (ts_path )
464
478
465
- format_obj = tsk .TskitFormat (ts_path )
479
+ format_obj = tsk .TskitFormat (tree_sequence )
466
480
schema = format_obj .generate_schema ()
467
481
call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
468
482
assert call_genotype_spec .dtype == "i4"
@@ -471,6 +485,7 @@ def test_genotype_dtype_i4(self, tmp_path):
471
485
@pytest .mark .parametrize (
472
486
"ts" ,
473
487
[
488
+ # Standard individuals-with-a-given-ploidy situation
474
489
add_mutations (
475
490
msprime .sim_ancestry (4 , ploidy = 1 , sequence_length = 10 , random_seed = 42 )
476
491
),
@@ -480,20 +495,20 @@ def test_genotype_dtype_i4(self, tmp_path):
480
495
add_mutations (
481
496
msprime .sim_ancestry (3 , ploidy = 12 , sequence_length = 10 , random_seed = 142 )
482
497
),
498
+ # No individuals, ploidy1
499
+ add_mutations (msprime .simulate (4 , length = 10 , random_seed = 412 )),
483
500
],
484
501
)
485
502
def test_against_tskit_vcf_output (ts , tmp_path ):
486
503
vcf_path = tmp_path / "ts.vcf"
487
- ts_path = tmp_path / "ts.trees"
488
- ts .dump (ts_path )
489
504
with open (vcf_path , "w" ) as f :
490
505
ts .write_vcf (f )
491
506
492
507
tskit_zarr = tmp_path / "tskit.zarr"
493
508
vcf_zarr = tmp_path / "vcf.zarr"
494
- tsk .convert (ts_path , tskit_zarr )
509
+ tsk .convert (ts , tskit_zarr , worker_processes = 0 )
495
510
496
- vcf .convert ([vcf_path ], vcf_zarr )
511
+ vcf .convert ([vcf_path ], vcf_zarr , worker_processes = 0 )
497
512
ds1 = sg .load_dataset (tskit_zarr )
498
513
ds2 = (
499
514
sg .load_dataset (vcf_zarr )
0 commit comments